xuwx1 / LightX2V / Commits / 878f5a48

Commit 878f5a48, authored Aug 04, 2025 by helloyongyang

    Fix torch compile

parent 88b7a2dd

Showing 5 changed files with 14 additions and 14 deletions (+14, -14):
lightx2v/infer.py                                    +1   -1
lightx2v/models/networks/wan/infer/post_infer.py     +1   -0
lightx2v/models/networks/wan/infer/pre_infer.py      +1   -0
lightx2v/models/networks/wan/infer/utils.py          +3   -3
lightx2v/models/runners/default_runner.py            +8   -10
lightx2v/infer.py

@@ -27,8 +27,8 @@ def init_runner(config):
     if CHECK_ENABLE_GRAPH_MODE():
         default_runner = RUNNER_REGISTER[config.model_cls](config)
-        default_runner.init_modules()
         runner = GraphRunner(default_runner)
+        runner.runner.init_modules()
     else:
         runner = RUNNER_REGISTER[config.model_cls](config)
         runner.init_modules()
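The change above moves module initialization so that, in graph mode, it happens on the runner held inside the GraphRunner wrapper (reached as runner.runner) rather than on the bare runner before wrapping. A minimal sketch of that wrapper/delegation shape is below; ToyRunner and ToyGraphRunner are illustrative names only, not the project's classes, which do considerably more than this.

import torch  # not needed for the sketch itself; shown for parity with the project

class ToyRunner:
    def __init__(self, config):
        self.config = config

    def init_modules(self):
        # stand-in for the real module setup
        print("initializing modules for", self.config["model_cls"])

class ToyGraphRunner:
    """Wraps an existing runner; the wrapped instance stays reachable as .runner."""

    def __init__(self, runner):
        self.runner = runner

config = {"model_cls": "wan2.1"}
default_runner = ToyRunner(config)
runner = ToyGraphRunner(default_runner)
# After the fix, initialization is invoked through the wrapper on the inner
# runner, i.e. only once the graph-mode wrapper is already in place.
runner.runner.init_modules()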
lightx2v/models/networks/wan/infer/post_infer.py

@@ -12,6 +12,7 @@ class WanPostInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, x, e, grid_sizes):
         if e.dim() == 2:
             modulation = weights.head_modulation.tensor  # 1, 2, dim
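The new decorator compiles WanPostInfer.infer only when graph mode is enabled; with disable=True, torch.compile is a no-op and hands back the original function, so the eager path pays no compilation cost. A self-contained sketch of the same pattern follows; graph_mode_enabled and the ENABLE_GRAPH_MODE variable are illustrative stand-ins for lightx2v's CHECK_ENABLE_GRAPH_MODE helper, whose actual definition is not shown in this diff.

import os
import torch

def graph_mode_enabled() -> bool:
    # Illustrative stand-in for CHECK_ENABLE_GRAPH_MODE(); the real check
    # lives in lightx2v.utils.envs and may use a different variable name.
    return os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"

@torch.compile(disable=not graph_mode_enabled())
def scale_and_shift(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    # With disable=True, torch.compile returns the original eager function,
    # so this behaves exactly as plain Python when graph mode is off.
    return x * (1 + scale) + shift

out = scale_and_shift(torch.randn(4, 8), torch.ones(8), torch.zeros(8))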
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -28,6 +28,7 @@ class WanPreInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, inputs, positive, kv_start=0, kv_end=0):
         x = self.scheduler.latents
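One property of this pattern worth noting: the decorator argument is evaluated when the class body executes, so the graph-mode check runs once at import time rather than on every call. The small sketch below demonstrates that with a counter; check_flag and Demo are purely illustrative.

import torch

FLAG_CHECKS = []

def check_flag() -> bool:
    # Records each evaluation; returns False so the decorated method stays eager.
    FLAG_CHECKS.append(True)
    return False

class Demo:
    @torch.compile(disable=not check_flag())  # evaluated once, at class definition
    def infer(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1

d = Demo()
d.infer(torch.zeros(2))
d.infer(torch.zeros(2))
assert len(FLAG_CHECKS) == 1  # the flag was not re-checked per call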
lightx2v/models/networks/wan/infer/utils.py

@@ -6,7 +6,7 @@ from lightx2v.utils.envs import *
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
@@ -22,7 +22,7 @@ def compute_freqs(c, grid_sizes, freqs):
 def compute_freqs_audio(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     f = f + 1  ##for r2v add 1 channel
     seq_len = f * h * w
     freqs_i = torch.cat(
@@ -39,7 +39,7 @@ def compute_freqs_audio(c, grid_sizes, freqs):
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0].tolist()
+    f, h, w = grid_sizes[0]
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
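Replacing grid_sizes[0].tolist() with direct tensor unpacking appears aimed at torch.compile friendliness: .tolist() materializes values on the host, which typically forces a graph break when these helpers are traced, whereas unpacking a small tensor keeps f, h and w as 0-d tensors the compiler can follow. A rough way to see the difference, assuming a recent PyTorch that exposes the torch._dynamo.explain API, is sketched below; the exact break counts depend on version and configuration.

import torch

def with_tolist(grid_sizes: torch.Tensor) -> torch.Tensor:
    # .tolist() forces a device-to-host copy; under torch.compile this is
    # a point where the traced graph typically has to break.
    f, h, w = grid_sizes[0].tolist()
    seq_len = f * h * w
    return seq_len * torch.ones(4)

def without_tolist(grid_sizes: torch.Tensor) -> torch.Tensor:
    # Unpacking keeps f, h, w as 0-d tensors the compiler can trace through.
    f, h, w = grid_sizes[0]
    seq_len = f * h * w
    return seq_len * torch.ones(4)

grid_sizes = torch.tensor([[4, 6, 8]])
report = torch._dynamo.explain(with_tolist)(grid_sizes)
print(report.graph_break_count)   # usually >= 1 because of .tolist()
report = torch._dynamo.explain(without_tolist)(grid_sizes)
print(report.graph_break_count)   # usually 0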
lightx2v/models/runners/default_runner.py

@@ -107,8 +107,9 @@ class DefaultRunner(BaseRunner):
     def set_progress_callback(self, callback):
         self.progress_callback = callback

-    def run(self):
-        total_steps = self.model.scheduler.infer_steps
+    def run(self, total_steps=None):
+        if total_steps is None:
+            total_steps = self.model.scheduler.infer_steps
         for step_index in range(total_steps):
             logger.info(f"==> step_index: {step_index + 1} / {total_steps}")
@@ -126,13 +127,10 @@ class DefaultRunner(BaseRunner):
         return self.model.scheduler.latents, self.model.scheduler.generator

-    def run_step(self, step_index=0):
-        self.init_scheduler()
+    def run_step(self):
         self.inputs = self.run_input_encoder()
-        self.model.scheduler.prepare(self.inputs["image_encoder_output"])
-        self.model.scheduler.step_pre(step_index=step_index)
-        self.model.infer(self.inputs)
-        self.model.scheduler.step_post()
+        self.set_target_shape()
+        self.run_dit(total_steps=1)

     def end_run(self):
         self.model.scheduler.clear()
@@ -171,14 +169,14 @@ class DefaultRunner(BaseRunner):
         }

     @ProfilingContext("Run DiT")
-    def _run_dit_local(self):
+    def _run_dit_local(self, total_steps=None):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.model = self.load_transformer()
         self.init_scheduler()
         self.model.scheduler.prepare(self.inputs["image_encoder_output"])
         if self.config.get("model_cls") == "wan2.2" and self.config["task"] == "i2v":
             self.inputs["image_encoder_output"]["vae_encoder_out"] = None
-        latents, generator = self.run()
+        latents, generator = self.run(total_steps)
         self.end_run()
         return latents, generator
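Taken together, the runner changes thread an optional total_steps from _run_dit_local into run(), and rebuild run_step() on top of set_target_shape() and run_dit(total_steps=1), so a single step reuses the same path as a full run. The toy sketch below mirrors that control flow; the classes are simplified stand-ins, not the real DefaultRunner.

class ToyScheduler:
    infer_steps = 50

class ToyRunner:
    def __init__(self):
        self.scheduler = ToyScheduler()

    def run(self, total_steps=None):
        # Optional override with a fallback to the scheduler's configured steps.
        if total_steps is None:
            total_steps = self.scheduler.infer_steps
        for step_index in range(total_steps):
            pass  # one denoising step per iteration
        return total_steps

    def run_dit(self, total_steps=None):
        # Shared entry point: a full run and a single step take the same path.
        return self.run(total_steps)

    def run_step(self):
        return self.run_dit(total_steps=1)

runner = ToyRunner()
assert runner.run_dit() == 50   # full run: falls back to infer_steps
assert runner.run_step() == 1   # single step: same code path, one iteration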