xuwx1 / LightX2V · Commit 7fc021e2
Authored Apr 21, 2025 by helloyongyang
Parent: cbf7820f

support runners & torch.compile

Showing 10 changed files with 75 additions and 21 deletions.
lightx2v/__main__.py                                          +12  -17
lightx2v/models/input_encoders/hf/__init__.py                  +0   -0
lightx2v/models/networks/__init__.py                           +0   -0
lightx2v/models/networks/hunyuan/infer/transformer_infer.py    +2   -0
lightx2v/models/networks/wan/infer/transformer_infer.py        +2   -0
lightx2v/models/runners/__init__.py                            +0   -0
lightx2v/models/runners/default_runner.py                     +27   -0
lightx2v/models/runners/graph_runner.py                       +23   -0
lightx2v/utils/envs.py                                         +8   -0
lightx2v/utils/profiler.py                                     +1   -4
lightx2v/__main__.py

@@ -10,6 +10,7 @@ import torchvision.transforms.functional as TF
 import numpy as np
 from PIL import Image
+from lightx2v.utils.envs import *
 from lightx2v.utils.utils import save_videos_grid, seed_all, cache_video
 from lightx2v.utils.profiler import ProfilingContext, ProfilingContext4Debug
 from lightx2v.utils.set_config import set_config
@@ -32,6 +33,9 @@ from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.autoencoder_kl_causal_3d.model import VideoEncoderKLCausal3DModel
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
+from lightx2v.models.runners.default_runner import DefaultRunner
+from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.common.ops import *
@@ -271,22 +275,6 @@ def init_scheduler(config, image_encoder_output):
     return scheduler


-def run_main_inference(model, inputs):
-    for step_index in range(model.scheduler.infer_steps):
-        print(f"==> step_index: {step_index + 1} / {model.scheduler.infer_steps}")
-        with ProfilingContext4Debug("step_pre"):
-            model.scheduler.step_pre(step_index=step_index)
-        with ProfilingContext4Debug("infer"):
-            model.infer(inputs)
-        with ProfilingContext4Debug("step_post"):
-            model.scheduler.step_post()
-    return model.scheduler.latents, model.scheduler.generator
-
-
-def run_vae(latents, generator, config):
-    images = vae_model.decode(latents, generator=generator, config=config)
-    return images
@@ -358,7 +346,14 @@ if __name__ == "__main__":
     gc.collect()
     torch.cuda.empty_cache()
-    latents, generator = run_main_inference(model, inputs)
+    if ENABLE_GRAPH_MODE:
+        default_runner = DefaultRunner(model, inputs)
+        runner = GraphRunner(default_runner)
+    else:
+        runner = DefaultRunner(model, inputs)
+    latents, generator = runner.run()

     if config.cpu_offload:
         scheduler.clear()
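The entry point no longer runs the denoising loop itself: it picks a runner from the ENABLE_GRAPH_MODE flag, and both runners expose the same run() returning (latents, generator), so the rest of __main__.py is unchanged. A minimal stand-alone sketch of that dispatch pattern, using a hypothetical build_runner helper that is not part of the commit:

import os

ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"

def build_runner(model, inputs):
    # hypothetical helper: mirrors the branch added to __main__.py
    from lightx2v.models.runners.default_runner import DefaultRunner
    from lightx2v.models.runners.graph_runner import GraphRunner

    runner = DefaultRunner(model, inputs)
    if ENABLE_GRAPH_MODE:
        runner = GraphRunner(runner)  # wraps and warms up torch.compile
    return runner

# latents, generator = build_runner(model, inputs).run()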
lightx2v/models/input_encoders/hf/__init__.py   (new empty file, mode 100755)

lightx2v/models/networks/__init__.py            (new empty file, mode 100755)
lightx2v/models/networks/hunyuan/infer/transformer_infer.py

@@ -3,6 +3,7 @@ from einops import rearrange
 from lightx2v.attentions import attention
 from .utils_bf16 import apply_rotary_emb
 from lightx2v.common.offload.manager import WeightStreamManager
+from lightx2v.utils.envs import *


 class HunyuanTransformerInfer:
@@ -25,6 +26,7 @@ class HunyuanTransformerInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.compile(disable=not ENABLE_GRAPH_MODE)
     def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
         return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
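Both transformer infer paths gate compilation behind torch.compile's disable flag: when ENABLE_GRAPH_MODE is false the decorator is a no-op and infer runs eagerly, so graph capture is strictly opt-in. A self-contained sketch of the same pattern with a toy function (not the commit's code):

import os
import torch

ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"

@torch.compile(disable=not ENABLE_GRAPH_MODE)
def toy_infer(x, y):
    # stand-in for a transformer infer step; compiled only when the flag is set
    return torch.nn.functional.silu(x) + y

out = toy_infer(torch.randn(8), torch.randn(8))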
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -2,6 +2,7 @@ import torch
 from .utils import compute_freqs, compute_freqs_dist, apply_rotary_emb
 from lightx2v.attentions import attention
 from lightx2v.common.offload.manager import WeightStreamManager
+from lightx2v.utils.envs import *


 class WanTransformerInfer:
@@ -34,6 +35,7 @@ class WanTransformerInfer:
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k, lq, lk

+    @torch.compile(disable=not ENABLE_GRAPH_MODE)
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context)
lightx2v/models/runners/__init__.py             (new empty file, mode 100755)
lightx2v/models/runners/default_runner.py   (new file, mode 100644)

from lightx2v.utils.profiler import ProfilingContext4Debug


class DefaultRunner:
    def __init__(self, model, inputs):
        self.model = model
        self.inputs = inputs

    def run(self):
        for step_index in range(self.model.scheduler.infer_steps):
            print(f"==> step_index: {step_index + 1} / {self.model.scheduler.infer_steps}")

            with ProfilingContext4Debug("step_pre"):
                self.model.scheduler.step_pre(step_index=step_index)

            with ProfilingContext4Debug("infer"):
                self.model.infer(self.inputs)

            with ProfilingContext4Debug("step_post"):
                self.model.scheduler.step_post()

        return self.model.scheduler.latents, self.model.scheduler.generator

    def run_step(self, step_index=0):
        self.model.scheduler.step_pre(step_index=step_index)
        self.model.infer(self.inputs)
        self.model.scheduler.step_post()
lightx2v/models/runners/graph_runner.py   (new file, mode 100644)

import copy
from lightx2v.utils.profiler import ProfilingContext4Debug


class GraphRunner:
    def __init__(self, runner):
        self.runner = runner
        self.compile()

    def compile(self):
        scheduler = copy.deepcopy(self.runner.model.scheduler)
        inputs = copy.deepcopy(self.runner.inputs)

        print("start compile...")
        with ProfilingContext4Debug("compile"):
            self.runner.run_step()
        print("end compile...")

        self.runner.model.set_scheduler(scheduler)
        setattr(self.runner, "inputs", inputs)

    def run(self):
        return self.runner.run()
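GraphRunner.compile() deep-copies the scheduler and inputs, runs a single run_step() so every @torch.compile-decorated infer gets traced once, then restores the snapshots so the warm-up does not advance the real sampling state. A stripped-down sketch of that snapshot/warm-up/restore pattern with toy objects (illustrative only, not part of the commit):

import copy

class _ToyScheduler:
    def __init__(self):
        self.step = 0

class _ToyRunner:
    def __init__(self):
        self.scheduler = _ToyScheduler()

    def run_step(self):
        self.scheduler.step += 1  # mutates state, like a real denoising step

runner = _ToyRunner()
snapshot = copy.deepcopy(runner.scheduler)  # 1. snapshot mutable state
runner.run_step()                           # 2. warm-up step triggers compilation
runner.scheduler = snapshot                 # 3. restore, so the real run starts clean
assert runner.scheduler.step == 0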
lightx2v/utils/envs.py   (new file, mode 100644)

import os

global ENABLE_PROFILING_DEBUG
ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"

global ENABLE_GRAPH_MODE
ENABLE_GRAPH_MODE = os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"
lightx2v/utils/profiler.py

 import time
-import os
 import torch
 from contextlib import ContextDecorator

-ENABLE_PROFILING_DEBUG = os.getenv("ENABLE_PROFILING_DEBUG", "false").lower() == "true"
+from lightx2v.utils.envs import *


 class _ProfilingContext(ContextDecorator):
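_ProfilingContext extends contextlib.ContextDecorator, so the same object can wrap a with-block or decorate a function; this commit only changes where its debug flag comes from (the shared envs module instead of a local os.getenv). A hedged sketch of that ContextDecorator pattern, not the library's actual implementation:

import time
from contextlib import ContextDecorator

ENABLE_PROFILING_DEBUG = True  # stand-in for the flag from lightx2v.utils.envs

class DebugTimer(ContextDecorator):
    """Usable both as `with DebugTimer("x"):` and as `@DebugTimer("x")`."""

    def __init__(self, name):
        self.name = name

    def __enter__(self):
        self.t0 = time.perf_counter()
        return self

    def __exit__(self, *exc):
        if ENABLE_PROFILING_DEBUG:
            print(f"[{self.name}] {time.perf_counter() - self.t0:.4f}s")
        return False  # never swallow exceptions

@DebugTimer("toy_step")
def toy_step():
    time.sleep(0.01)

toy_step()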