Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
cbf7820f
Commit
cbf7820f
authored
Apr 21, 2025
by
helloyongyang
Browse files
support _ProfilingContext and _NullContext for speed test
parent
75c03057
Changes
10
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
74 additions
and
42 deletions
+74
-42
lightx2v/__main__.py
lightx2v/__main__.py
+24
-42
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+39
-0
scripts/run_hunyuan_i2v.sh
scripts/run_hunyuan_i2v.sh
+2
-0
scripts/run_hunyuan_t2v.sh
scripts/run_hunyuan_t2v.sh
+2
-0
scripts/run_hunyuan_t2v_dist.sh
scripts/run_hunyuan_t2v_dist.sh
+1
-0
scripts/run_hunyuan_t2v_taylorseer.sh
scripts/run_hunyuan_t2v_taylorseer.sh
+1
-0
scripts/run_wan_i2v.sh
scripts/run_wan_i2v.sh
+1
-0
scripts/run_wan_i2v_with_lora.sh
scripts/run_wan_i2v_with_lora.sh
+2
-0
scripts/run_wan_t2v.sh
scripts/run_wan_t2v.sh
+1
-0
scripts/run_wan_t2v_dist.sh
scripts/run_wan_t2v_dist.sh
+1
-0
No files found.
lightx2v/__main__.py
View file @
cbf7820f
...
@@ -8,8 +8,12 @@ import json
...
@@ -8,8 +8,12 @@ import json
import
torchvision
import
torchvision
import
torchvision.transforms.functional
as
TF
import
torchvision.transforms.functional
as
TF
import
numpy
as
np
import
numpy
as
np
from
contextlib
import
contextmanager
from
PIL
import
Image
from
PIL
import
Image
from
lightx2v.utils.utils
import
save_videos_grid
,
seed_all
,
cache_video
from
lightx2v.utils.profiler
import
ProfilingContext
,
ProfilingContext4Debug
from
lightx2v.utils.set_config
import
set_config
from
lightx2v.models.input_encoders.hf.llama.model
import
TextEncoderHFLlamaModel
from
lightx2v.models.input_encoders.hf.llama.model
import
TextEncoderHFLlamaModel
from
lightx2v.models.input_encoders.hf.clip.model
import
TextEncoderHFClipModel
from
lightx2v.models.input_encoders.hf.clip.model
import
TextEncoderHFClipModel
from
lightx2v.models.input_encoders.hf.t5.model
import
T5EncoderModel
from
lightx2v.models.input_encoders.hf.t5.model
import
T5EncoderModel
...
@@ -27,19 +31,8 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
...
@@ -27,19 +31,8 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
from
lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model
import
VideoEncoderKLCausal3DModel
from
lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model
import
VideoEncoderKLCausal3DModel
from
lightx2v.models.video_encoders.hf.wan.vae
import
WanVAE
from
lightx2v.models.video_encoders.hf.wan.vae
import
WanVAE
from
lightx2v.utils.utils
import
save_videos_grid
,
seed_all
,
cache_video
from
lightx2v.common.ops
import
*
from
lightx2v.utils.set_config
import
set_config
from
lightx2v.common.ops
import
*
@contextmanager
def time_duration(label: str = ""):
    """Time a code block, printing its start timestamp and elapsed seconds.

    CUDA is synchronized before and after the block (when available) so the
    measured interval covers all queued GPU work, not just kernel launches.
    """
    # Guard the sync so timing also works on CPU-only machines,
    # where torch.cuda.synchronize() would raise.
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start_time = time.time()
    try:
        yield
    finally:
        # Report even if the timed block raised; without the finally the
        # trailing sync/print would be silently skipped on exceptions.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        end_time = time.time()
        print(f"==> {label} start: {time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(start_time))} cost {end_time - start_time:.2f} seconds")
def
load_models
(
config
):
def
load_models
(
config
):
...
@@ -63,7 +56,7 @@ def load_models(config):
...
@@ -63,7 +56,7 @@ def load_models(config):
vae_model
=
VideoEncoderKLCausal3DModel
(
config
.
model_path
,
dtype
=
torch
.
float16
,
device
=
init_device
,
config
=
config
)
vae_model
=
VideoEncoderKLCausal3DModel
(
config
.
model_path
,
dtype
=
torch
.
float16
,
device
=
init_device
,
config
=
config
)
elif
config
.
model_cls
==
"wan2.1"
:
elif
config
.
model_cls
==
"wan2.1"
:
with
time_duration
(
"Load Text Encoder"
):
with
ProfilingContext
(
"Load Text Encoder"
):
text_encoder
=
T5EncoderModel
(
text_encoder
=
T5EncoderModel
(
text_len
=
config
[
"text_len"
],
text_len
=
config
[
"text_len"
],
dtype
=
torch
.
bfloat16
,
dtype
=
torch
.
bfloat16
,
...
@@ -73,20 +66,20 @@ def load_models(config):
...
@@ -73,20 +66,20 @@ def load_models(config):
shard_fn
=
None
,
shard_fn
=
None
,
)
)
text_encoders
=
[
text_encoder
]
text_encoders
=
[
text_encoder
]
with
time_duration
(
"Load Wan Model"
):
with
ProfilingContext
(
"Load Wan Model"
):
model
=
WanModel
(
config
.
model_path
,
config
,
init_device
)
model
=
WanModel
(
config
.
model_path
,
config
,
init_device
)
if
config
.
lora_path
:
if
config
.
lora_path
:
lora_wrapper
=
WanLoraWrapper
(
model
)
lora_wrapper
=
WanLoraWrapper
(
model
)
with
time_duration
(
"Load LoRA Model"
):
with
ProfilingContext
(
"Load LoRA Model"
):
lora_name
=
lora_wrapper
.
load_lora
(
config
.
lora_path
)
lora_name
=
lora_wrapper
.
load_lora
(
config
.
lora_path
)
lora_wrapper
.
apply_lora
(
lora_name
,
config
.
strength_model
)
lora_wrapper
.
apply_lora
(
lora_name
,
config
.
strength_model
)
print
(
f
"Loaded LoRA:
{
lora_name
}
"
)
print
(
f
"Loaded LoRA:
{
lora_name
}
"
)
with
time_duration
(
"Load WAN VAE Model"
):
with
ProfilingContext
(
"Load WAN VAE Model"
):
vae_model
=
WanVAE
(
vae_pth
=
os
.
path
.
join
(
config
.
model_path
,
"Wan2.1_VAE.pth"
),
device
=
init_device
,
parallel
=
config
.
parallel_vae
)
vae_model
=
WanVAE
(
vae_pth
=
os
.
path
.
join
(
config
.
model_path
,
"Wan2.1_VAE.pth"
),
device
=
init_device
,
parallel
=
config
.
parallel_vae
)
if
config
.
task
==
"i2v"
:
if
config
.
task
==
"i2v"
:
with
time_duration
(
"Load Image Encoder"
):
with
ProfilingContext
(
"Load Image Encoder"
):
image_encoder
=
CLIPModel
(
image_encoder
=
CLIPModel
(
dtype
=
torch
.
float16
,
dtype
=
torch
.
float16
,
device
=
init_device
,
device
=
init_device
,
...
@@ -280,27 +273,16 @@ def init_scheduler(config, image_encoder_output):
...
@@ -280,27 +273,16 @@ def init_scheduler(config, image_encoder_output):
def
run_main_inference
(
model
,
inputs
):
def
run_main_inference
(
model
,
inputs
):
for
step_index
in
range
(
model
.
scheduler
.
infer_steps
):
for
step_index
in
range
(
model
.
scheduler
.
infer_steps
):
torch
.
cuda
.
synchronize
()
print
(
f
"==> step_index:
{
step_index
+
1
}
/
{
model
.
scheduler
.
infer_steps
}
"
)
time1
=
time
.
time
()
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
torch
.
cuda
.
synchronize
()
time2
=
time
.
time
()
model
.
infer
(
inputs
)
torch
.
cuda
.
synchronize
()
time3
=
time
.
time
()
model
.
scheduler
.
step_post
()
with
ProfilingContext4Debug
(
"step_pre"
):
model
.
scheduler
.
step_pre
(
step_index
=
step_index
)
torch
.
cuda
.
synchronize
()
with
ProfilingContext4Debug
(
"infer"
):
time4
=
time
.
time
(
)
model
.
infer
(
inputs
)
print
(
f
"step
{
step_index
}
infer time:
{
time3
-
time2
}
"
)
with
ProfilingContext4Debug
(
"step_post"
):
print
(
f
"step
{
step_index
}
all time:
{
time4
-
time1
}
"
)
model
.
scheduler
.
step_post
()
print
(
"*"
*
10
)
return
model
.
scheduler
.
latents
,
model
.
scheduler
.
generator
return
model
.
scheduler
.
latents
,
model
.
scheduler
.
generator
...
@@ -344,7 +326,7 @@ if __name__ == "__main__":
...
@@ -344,7 +326,7 @@ if __name__ == "__main__":
parser
.
add_argument
(
"--strength_model"
,
type
=
float
,
default
=
1.0
)
parser
.
add_argument
(
"--strength_model"
,
type
=
float
,
default
=
1.0
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
start_time
=
time
.
time
()
start_time
=
time
.
perf_counter
()
print
(
f
"args:
{
args
}
"
)
print
(
f
"args:
{
args
}
"
)
seed_all
(
args
.
seed
)
seed_all
(
args
.
seed
)
...
@@ -356,7 +338,7 @@ if __name__ == "__main__":
...
@@ -356,7 +338,7 @@ if __name__ == "__main__":
print
(
f
"config:
{
config
}
"
)
print
(
f
"config:
{
config
}
"
)
with
time_duration
(
"Load models"
):
with
ProfilingContext
(
"Load models"
):
model
,
text_encoders
,
vae_model
,
image_encoder
=
load_models
(
config
)
model
,
text_encoders
,
vae_model
,
image_encoder
=
load_models
(
config
)
if
config
[
"task"
]
in
[
"i2v"
]:
if
config
[
"task"
]
in
[
"i2v"
]:
...
@@ -364,7 +346,7 @@ if __name__ == "__main__":
...
@@ -364,7 +346,7 @@ if __name__ == "__main__":
else
:
else
:
image_encoder_output
=
{
"clip_encoder_out"
:
None
,
"vae_encode_out"
:
None
}
image_encoder_output
=
{
"clip_encoder_out"
:
None
,
"vae_encode_out"
:
None
}
with
time_duration
(
"Run Text Encoder"
):
with
ProfilingContext
(
"Run Text Encoder"
):
text_encoder_output
=
run_text_encoder
(
config
[
"prompt"
],
text_encoders
,
config
,
image_encoder_output
)
text_encoder_output
=
run_text_encoder
(
config
[
"prompt"
],
text_encoders
,
config
,
image_encoder_output
)
inputs
=
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
}
inputs
=
{
"text_encoder_output"
:
text_encoder_output
,
"image_encoder_output"
:
image_encoder_output
}
...
@@ -383,15 +365,15 @@ if __name__ == "__main__":
...
@@ -383,15 +365,15 @@ if __name__ == "__main__":
del
text_encoder_output
,
image_encoder_output
,
model
,
text_encoders
,
scheduler
del
text_encoder_output
,
image_encoder_output
,
model
,
text_encoders
,
scheduler
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
with
time_duration
(
"Run VAE"
):
with
ProfilingContext
(
"Run VAE"
):
images
=
run_vae
(
latents
,
generator
,
config
)
images
=
run_vae
(
latents
,
generator
,
config
)
if
not
config
.
parallel_attn_type
or
(
config
.
parallel_attn_type
and
dist
.
get_rank
()
==
0
):
if
not
config
.
parallel_attn_type
or
(
config
.
parallel_attn_type
and
dist
.
get_rank
()
==
0
):
with
time_duration
(
"Save video"
):
with
ProfilingContext
(
"Save video"
):
if
config
.
model_cls
==
"wan2.1"
:
if
config
.
model_cls
==
"wan2.1"
:
cache_video
(
tensor
=
images
,
save_file
=
config
.
save_video_path
,
fps
=
16
,
nrow
=
1
,
normalize
=
True
,
value_range
=
(
-
1
,
1
))
cache_video
(
tensor
=
images
,
save_file
=
config
.
save_video_path
,
fps
=
16
,
nrow
=
1
,
normalize
=
True
,
value_range
=
(
-
1
,
1
))
else
:
else
:
save_videos_grid
(
images
,
config
.
save_video_path
,
fps
=
24
)
save_videos_grid
(
images
,
config
.
save_video_path
,
fps
=
24
)
end_time
=
time
.
time
()
end_time
=
time
.
perf_counter
()
print
(
f
"Total cost:
{
end_time
-
start_time
}
"
)
print
(
f
"Total cost:
{
end_time
-
start_time
}
"
)
lightx2v/utils/profiler.py
0 → 100644
View file @
cbf7820f
import
time
import
os
import
torch
from
contextlib
import
ContextDecorator
# Debug profiling is opt-in: export ENABLE_PROFILING_DEBUG=true to turn it on;
# any other value (or an unset variable) leaves it disabled.
ENABLE_PROFILING_DEBUG = os.environ.get("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
class
_ProfilingContext
(
ContextDecorator
):
def
__init__
(
self
,
name
):
self
.
name
=
name
def
__enter__
(
self
):
torch
.
cuda
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
torch
.
cuda
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
print
(
f
"[Profile]
{
self
.
name
}
cost
{
elapsed
:.
6
f
}
seconds"
)
return
False
class
_NullContext
(
ContextDecorator
):
# Context manager without decision branch logic overhead
def
__init__
(
self
,
*
args
,
**
kwargs
):
pass
def
__enter__
(
self
):
return
self
def
__exit__
(
self
,
*
args
):
return
False
# Public names: the plain profiler is always active, while the debug
# profiler collapses to a no-op unless ENABLE_PROFILING_DEBUG is set.
ProfilingContext = _ProfilingContext
if ENABLE_PROFILING_DEBUG:
    ProfilingContext4Debug = _ProfilingContext
else:
    ProfilingContext4Debug = _NullContext
scripts/run_hunyuan_i2v.sh
View file @
cbf7820f
...
@@ -23,6 +23,8 @@ fi
...
@@ -23,6 +23,8 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
hunyuan
\
--model_cls
hunyuan
\
--model_path
$model_path
\
--model_path
$model_path
\
...
...
scripts/run_hunyuan_t2v.sh
View file @
cbf7820f
...
@@ -23,6 +23,8 @@ fi
...
@@ -23,6 +23,8 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
hunyuan
\
--model_cls
hunyuan
\
--model_path
$model_path
\
--model_path
$model_path
\
...
...
scripts/run_hunyuan_t2v_dist.sh
View file @
cbf7820f
...
@@ -23,6 +23,7 @@ fi
...
@@ -23,6 +23,7 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
torchrun
--nproc_per_node
=
4
${
lightx2v_path
}
/lightx2v/__main__.py
\
torchrun
--nproc_per_node
=
4
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
hunyuan
\
--model_cls
hunyuan
\
...
...
scripts/run_hunyuan_t2v_taylorseer.sh
View file @
cbf7820f
...
@@ -23,6 +23,7 @@ fi
...
@@ -23,6 +23,7 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
hunyuan
\
--model_cls
hunyuan
\
...
...
scripts/run_wan_i2v.sh
View file @
cbf7820f
...
@@ -29,6 +29,7 @@ fi
...
@@ -29,6 +29,7 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
wan2.1
\
--model_cls
wan2.1
\
...
...
scripts/run_wan_i2v_with_lora.sh
View file @
cbf7820f
...
@@ -26,6 +26,8 @@ fi
...
@@ -26,6 +26,8 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
-m
lightx2v
\
python
-m
lightx2v
\
--model_cls
wan2.1
\
--model_cls
wan2.1
\
--task
i2v
\
--task
i2v
\
...
...
scripts/run_wan_t2v.sh
View file @
cbf7820f
...
@@ -29,6 +29,7 @@ fi
...
@@ -29,6 +29,7 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
python
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
wan2.1
\
--model_cls
wan2.1
\
...
...
scripts/run_wan_t2v_dist.sh
View file @
cbf7820f
...
@@ -29,6 +29,7 @@ fi
...
@@ -29,6 +29,7 @@ fi
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
PYTHONPATH
=
${
lightx2v_path
}
:
$PYTHONPATH
export
ENABLE_PROFILING_DEBUG
=
true
torchrun
--nproc_per_node
=
4
${
lightx2v_path
}
/lightx2v/__main__.py
\
torchrun
--nproc_per_node
=
4
${
lightx2v_path
}
/lightx2v/__main__.py
\
--model_cls
wan2.1
\
--model_cls
wan2.1
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment