xuwx1 / LightX2V / Commits / 701075f4

Commit 701075f4, authored Sep 15, 2025 by Yang Yong(雍洋), committed by GitHub on Sep 15, 2025

refactor compiler (#301)

Parent: 60c421f4
Changes: 66 · Showing 20 changed files with 160 additions and 71 deletions (+160, −71)
configs/seko_talk/seko_talk_15_base_compile.json (+19, −0)
configs/seko_talk/seko_talk_16_fp8_dist_compile.json (+30, −0)
lightx2v/common/ops/attn/flash_attn.py (+0, −2)
lightx2v/common/ops/attn/sage_attn.py (+0, −1)
lightx2v/common/ops/attn/torch_sdpa.py (+0, −1)
lightx2v/common/ops/attn/ulysses_attn.py (+9, −5)
lightx2v/deploy/worker/hub.py (+1, −8)
lightx2v/infer.py (+4, −9)
lightx2v/models/networks/hunyuan/infer/transformer_infer.py (+0, −1)
lightx2v/models/networks/qwen_image/infer/pre_infer.py (+0, −1)
lightx2v/models/networks/wan/audio_model.py (+68, −0)
lightx2v/models/networks/wan/causvid_model.py (+0, −6)
lightx2v/models/networks/wan/infer/audio/post_infer.py (+1, −1)
lightx2v/models/networks/wan/infer/audio/pre_infer.py (+15, −15)
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py (+0, −2)
lightx2v/models/networks/wan/infer/post_infer.py (+1, −0)
lightx2v/models/networks/wan/infer/pre_infer.py (+6, −7)
lightx2v/models/networks/wan/infer/transformer_infer.py (+1, −2)
lightx2v/models/networks/wan/model.py (+4, −8)
lightx2v/models/runners/cogvideox/cogvidex_runner.py (+1, −2)
configs/seko_talk/seko_talk_15_base_compile.json (new file, mode 100644)

{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "flash_attn3",
    "cross_attn_1_type": "flash_attn3",
    "cross_attn_2_type": "flash_attn3",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "compile": true,
    "compile_shapes": [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
}
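The compile / compile_shapes fields turn on ahead-of-time graph capture for a fixed menu of output resolutions. A rough sketch of how a caller might consume them; load_model is a stand-in, not the project's actual entry point:

import json

# Hypothetical illustration of reading "compile"/"compile_shapes" from a config
# like the one above. load_model() is a placeholder; the real project wires this
# up inside its runners.
def warm_up_from_config(config_path, load_model):
    with open(config_path) as f:
        config = json.load(f)

    model = load_model(config)
    if config.get("compile", False):
        # Pre-compile one graph per (height, width) pair listed in the config,
        # so inference at those resolutions avoids recompilation latency.
        model.compile(config["compile_shapes"])
    return model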
configs/seko_talk/seko_talk_16_fp8_dist_compile.json (new file, mode 100755)

{
    "infer_steps": 4,
    "target_fps": 16,
    "video_duration": 360,
    "audio_sr": 16000,
    "target_video_length": 81,
    "resize_mode": "adaptive",
    "self_attn_1_type": "sage_attn2",
    "cross_attn_1_type": "sage_attn2",
    "cross_attn_2_type": "sage_attn2",
    "seed": 42,
    "sample_guide_scale": 1.0,
    "sample_shift": 5,
    "enable_cfg": false,
    "cpu_offload": false,
    "use_31_block": false,
    "parallel": {
        "seq_p_size": 8,
        "seq_p_attn_type": "ulysses"
    },
    "mm_config": {
        "mm_type": "W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl"
    },
    "adapter_quantized": true,
    "adapter_quant_scheme": "fp8",
    "t5_quantized": true,
    "t5_quant_scheme": "fp8",
    "compile": true,
    "compile_shapes": [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
}
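This variant layers 8-way Ulysses sequence parallelism and fp8 weight/activation quantization on top of the same compile settings. For orientation only, here is a sketch of how an 8-rank sequence-parallel process group could be set up before running with "seq_p_size": 8; the project's real group management lives in its parallel utilities, and every name below is illustrative:

import torch
import torch.distributed as dist

# Illustrative only: one contiguous process group per block of seq_p_size ranks.
def init_seq_parallel(seq_p_size=8):
    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank % torch.cuda.device_count())
    assert dist.get_world_size() % seq_p_size == 0
    groups = []
    for start in range(0, dist.get_world_size(), seq_p_size):
        ranks = list(range(start, start + seq_p_size))
        groups.append(dist.new_group(ranks=ranks))  # collective: every rank creates every group
    return groups[rank // seq_p_size]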
lightx2v/common/ops/attn/flash_attn.py

@@ -33,7 +33,6 @@ class FlashAttn2Weight(AttnWeightTemplate):
         max_seqlen_q=None,
         max_seqlen_kv=None,
         model_cls=None,
-        mask_map=None,
     ):
         x = flash_attn_varlen_func(
             q,
@@ -62,7 +61,6 @@ class FlashAttn3Weight(AttnWeightTemplate):
         max_seqlen_q=None,
         max_seqlen_kv=None,
         model_cls=None,
-        mask_map=None,
     ):
         x = flash_attn_varlen_func_v3(
             q,
lightx2v/common/ops/attn/sage_attn.py

@@ -34,7 +34,6 @@ class SageAttn2Weight(AttnWeightTemplate):
         max_seqlen_q=None,
         max_seqlen_kv=None,
         model_cls=None,
-        mask_map=None,
     ):
         q, k, v = q.contiguous(), k.contiguous(), v.contiguous()
         if model_cls == "hunyuan":
lightx2v/common/ops/attn/torch_sdpa.py

@@ -24,7 +24,6 @@ class TorchSDPAWeight(AttnWeightTemplate):
         max_seqlen_q=None,
         max_seqlen_kv=None,
         model_cls=None,
-        mask_map=None,
     ):
         if q.ndim == 3:
             q, k, v = q.unsqueeze(0), k.unsqueeze(0), v.unsqueeze(0)
lightx2v/common/ops/attn/ulysses_attn.py

@@ -86,15 +86,19 @@ class UlyssesAttnWeight(AttnWeightTemplate):
         gathered_txt_attn = [torch.empty_like(txt_attn) for _ in range(world_size)]
         dist.all_gather(gathered_txt_attn, txt_attn, group=seq_p_group)

-        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
-        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert the head layout back to the sequence layout
-        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
-        torch.cuda.synchronize()  # make sure the CUDA ops have finished
+        # process the image attention result
+        img_attn = self._reshape_img_attn(img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group)

         txt_attn = torch.cat(gathered_txt_attn, dim=1)  # gather the text attention results from all ranks

         # concatenate the image and text attention results
         attn = torch.cat([img_attn, txt_attn], dim=0)

         return attn  # return the final attention result

+    @torch.compiler.disable
+    def _reshape_img_attn(self, img_attn, world_size, shard_seqlen, shard_heads, hidden_dims, seq_p_group):
+        img_attn = img_attn.reshape(world_size * shard_seqlen, shard_heads, hidden_dims)  # reshape the image attention result
+        img_attn = all2all_head2seq(img_attn, group=seq_p_group)  # convert the head layout back to the sequence layout
+        img_attn = img_attn.reshape(shard_seqlen, -1)  # reshape to [shard_seqlen, -1]
+        torch.cuda.synchronize()  # make sure the CUDA ops have finished
+        return img_attn
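The reshape / all-to-all / synchronize sequence is now fenced inside a helper decorated with @torch.compiler.disable, so torch.compile traces around it instead of hitting graph breaks on the collective and the explicit synchronize. A minimal standalone sketch of the same pattern, assuming PyTorch 2.1+ where torch.compiler.disable exists; the class and method names are made up for illustration:

import torch

class Block(torch.nn.Module):
    def forward(self, x):
        x = x * 2.0                 # traced and compiled
        x = self._host_sync(x)      # skipped by the compiler, runs eagerly
        return x + 1.0              # traced and compiled

    @torch.compiler.disable
    def _host_sync(self, x):
        # Work that would otherwise cause graph breaks (explicit synchronization,
        # collectives, data-dependent host logic) is fenced off here.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        return x

compiled = torch.compile(Block())
print(compiled(torch.ones(4)))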
lightx2v/deploy/worker/hub.py

@@ -14,8 +14,6 @@ from loguru import logger
 from lightx2v.deploy.common.utils import class_try_catch_async
 from lightx2v.infer import init_runner  # noqa
-from lightx2v.models.runners.graph_runner import GraphRunner
-from lightx2v.utils.envs import CHECK_ENABLE_GRAPH_MODE
 from lightx2v.utils.profiler import *
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.utils.set_config import set_config, set_parallel_config
@@ -189,12 +187,7 @@ class PipelineWorker(BaseWorker):
     def __init__(self, args):
         super().__init__(args)
         self.runner.init_modules()
-        if CHECK_ENABLE_GRAPH_MODE():
-            self.init_temp_params()
-            self.graph_runner = GraphRunner(self.runner)
-            self.run_func = self.graph_runner.run_pipeline
-        else:
-            self.run_func = self.runner.run_pipeline
+        self.run_func = self.runner.run_pipeline

     def init_temp_params(self):
         cur_dir = os.path.dirname(os.path.abspath(__file__))
lightx2v/infer.py

 import argparse

+import torch
 import torch.distributed as dist
 from loguru import logger

 from lightx2v.common.ops import *
 from lightx2v.models.runners.cogvideox.cogvidex_runner import CogvideoxRunner  # noqa: F401
-from lightx2v.models.runners.graph_runner import GraphRunner
 from lightx2v.models.runners.hunyuan.hunyuan_runner import HunyuanRunner  # noqa: F401
 from lightx2v.models.runners.qwen_image.qwen_image_runner import QwenImageRunner  # noqa: F401
 from lightx2v.models.runners.wan.wan_audio_runner import Wan22AudioRunner, WanAudioRunner  # noqa: F401
@@ -23,14 +23,9 @@ from lightx2v.utils.utils import seed_all

 def init_runner(config):
     seed_all(config.seed)
-    if CHECK_ENABLE_GRAPH_MODE():
-        default_runner = RUNNER_REGISTER[config.model_cls](config)
-        default_runner.init_modules()
-        runner = GraphRunner(default_runner)
-    else:
-        runner = RUNNER_REGISTER[config.model_cls](config)
-        runner.init_modules()
+    torch.set_grad_enabled(False)
+    runner = RUNNER_REGISTER[config.model_cls](config)
+    runner.init_modules()
     return runner
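With this change init_runner always constructs the registered runner directly, calls init_modules(), and disables autograd globally; the choice between an eager runner and a GraphRunner wrapper is no longer made here, and graph capture instead happens inside the models (see wan/model.py below). A hedged usage sketch, with everything around the call elided:

# Illustrative only; CLI parsing and config merging are elided.
from lightx2v.infer import init_runner

def build(config):
    runner = init_runner(config)   # registry lookup + init_modules(); no GraphRunner wrapping
    run = runner.run_pipeline      # same entry point the deploy worker binds to run_func
    # Call run(...) with the task inputs; the exact signature is not part of this diff.
    return run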
lightx2v/models/networks/hunyuan/infer/transformer_infer.py

@@ -29,7 +29,6 @@ class HunyuanTransformerInfer(BaseTransformerInfer):
         else:
             self.infer_func = self._infer_without_offload

-    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec=None, frist_frame_token_num=None):
         return self.infer_func(weights, img, txt, vec, cu_seqlens_qkv, max_seqlen_qkv, freqs_cis, token_replace_vec, frist_frame_token_num)
lightx2v/models/networks/qwen_image/infer/pre_infer.py

@@ -24,7 +24,6 @@ class QwenImagePreInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

-    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, hidden_states, timestep, guidance, encoder_hidden_states_mask, encoder_hidden_states, img_shapes, txt_seq_lens, attention_kwargs):
         hidden_states_0 = hidden_states
         hidden_states = self.img_in(hidden_states)
lightx2v/models/networks/wan/audio_model.py

 import os

+import torch
 import torch.distributed as dist
+from loguru import logger

 from lightx2v.models.networks.wan.infer.audio.post_infer import WanAudioPostInfer
 from lightx2v.models.networks.wan.infer.audio.pre_infer import WanAudioPreInfer
@@ -46,3 +48,69 @@ class WanAudioModel(WanModel):
         self.pre_infer_class = WanAudioPreInfer
         self.post_infer_class = WanAudioPostInfer
         self.transformer_infer_class = WanAudioTransformerInfer

+    def get_graph_name(self, shape):
+        return f"graph_{shape[0]}x{shape[1]}"
+
+    def start_compile(self, shape):
+        graph_name = self.get_graph_name(shape)
+        logger.info(f"[Compile] Compile shape: {shape}, graph_name: {graph_name}")
+        target_video_length = self.config.get("target_video_length", 81)
+        latents_length = (target_video_length - 1) // 16 * 4 + 1
+        latents_h = shape[0] // self.config.vae_stride[1]
+        latents_w = shape[1] // self.config.vae_stride[2]
+
+        new_inputs = {}
+        new_inputs["text_encoder_output"] = {}
+        new_inputs["text_encoder_output"]["context"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
+        new_inputs["text_encoder_output"]["context_null"] = torch.randn(1, 512, 4096, dtype=torch.bfloat16).cuda()
+        new_inputs["image_encoder_output"] = {}
+        new_inputs["image_encoder_output"]["clip_encoder_out"] = torch.randn(257, 1280, dtype=torch.bfloat16).cuda()
+        new_inputs["image_encoder_output"]["vae_encoder_out"] = torch.randn(16, 1, latents_h, latents_w, dtype=torch.bfloat16).cuda()
+        new_inputs["audio_encoder_output"] = torch.randn(1, latents_length, 128, 1024, dtype=torch.bfloat16).cuda()
+        new_inputs["previmg_encoder_output"] = {}
+        new_inputs["previmg_encoder_output"]["prev_latents"] = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
+        new_inputs["previmg_encoder_output"]["prev_mask"] = torch.randn(4, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
+
+        self.scheduler.latents = torch.randn(16, latents_length, latents_h, latents_w, dtype=torch.bfloat16).cuda()
+        self.scheduler.timestep_input = torch.tensor([600.0], dtype=torch.float32).cuda()
+        self.scheduler.audio_adapter_t_emb = torch.randn(1, 3, 5120, dtype=torch.bfloat16).cuda()
+
+        self._infer_cond_uncond(new_inputs, infer_condition=True, graph_name=graph_name)
+
+    def compile(self, compile_shapes):
+        self.check_compile_shapes(compile_shapes)
+        self.enable_compile_mode("_infer_cond_uncond")
+
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == 0:
+                self.to_cuda()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cuda()
+                self.transformer_weights.non_block_weights_to_cuda()
+
+        for shape in compile_shapes:
+            self.start_compile(shape)
+
+        if self.cpu_offload:
+            if self.offload_granularity == "model" and self.scheduler.step_index == self.scheduler.infer_steps - 1:
+                self.to_cpu()
+            elif self.offload_granularity != "model":
+                self.pre_weight.to_cpu()
+                self.transformer_weights.non_block_weights_to_cpu()
+
+        self.disable_compile_mode("_infer_cond_uncond")
+        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
+
+    def check_compile_shapes(self, compile_shapes):
+        for shape in compile_shapes:
+            assert shape in [[480, 832], [544, 960], [720, 1280], [832, 480], [960, 544], [1280, 720], [480, 480], [576, 576], [704, 704], [960, 960]]
+
+    def select_graph_for_compile(self):
+        logger.info(f"tgt_h, tgt_w : {self.config.get('tgt_h')}, {self.config.get('tgt_w')}")
+        self.select_graph("_infer_cond_uncond", f"graph_{self.config.get('tgt_h')}x{self.config.get('tgt_w')}")
+        logger.info(f"[Compile] Compile status: {self.get_compile_status()}")
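WanAudioModel gains the graph pre-compilation path: compile() validates the requested shapes, switches _infer_cond_uncond into compile mode, runs one dummy forward per shape via start_compile(), and select_graph_for_compile() later picks the graph matching tgt_h/tgt_w. A hedged sketch of the call order a runner might use; model and config construction are elided and this is not the project's runner code:

# Illustrative call order for the new compile path.
def precompile_and_select(model, config):
    if config.get("compile", False):
        model.compile(config["compile_shapes"])   # warm up with dummy inputs, one graph per [h, w]
    model.select_graph_for_compile()              # pick the graph matching tgt_h x tgt_w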
lightx2v/models/networks/wan/causvid_model.py

@@ -2,7 +2,6 @@ import os

 import torch

-from lightx2v.common.ops.attn.radial_attn import MaskMap
 from lightx2v.models.networks.wan.infer.causvid.transformer_infer import (
     WanTransformerInferCausVid,
 )
@@ -45,11 +44,6 @@ class WanCausVidModel(WanModel):
     @torch.no_grad()
     def infer(self, inputs, kv_start, kv_end):
-        if self.transformer_infer.mask_map is None:
-            _, c, h, w = self.scheduler.latents.shape
-            video_token_num = c * (h // 2) * (w // 2)
-            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
         if self.config["cpu_offload"]:
             self.pre_weight.to_cuda()
             self.transformer_weights.post_weights_to_cuda()
lightx2v/models/networks/wan/infer/audio/post_infer.py

@@ -8,7 +8,7 @@ class WanAudioPostInfer(WanPostInfer):
     def __init__(self, config):
         super().__init__(config)

-    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
+    @torch.no_grad()
     def infer(self, x, pre_infer_out):
         x = x[: pre_infer_out.seq_lens[0]]
lightx2v/models/networks/wan/infer/audio/pre_infer.py

@@ -23,29 +23,30 @@ class WanAudioPreInfer(WanPreInfer):
         ).cuda()
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.text_len = config["text_len"]
         self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

+    @torch.no_grad()
     def infer(self, weights, inputs):
+        infer_condition, latents, timestep_input = self.scheduler.infer_condition, self.scheduler.latents, self.scheduler.timestep_input
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"]
-        hidden_states = self.scheduler.latents
+        hidden_states = latents
         if self.config.model_cls != "wan2.2_audio":
             prev_mask = inputs["previmg_encoder_output"]["prev_mask"]
             hidden_states = torch.cat([hidden_states, prev_mask, prev_latents], dim=0)

         x = hidden_states
-        t = self.scheduler.timestep_input
+        t = timestep_input
-        if self.scheduler.infer_condition:
+        if infer_condition:
             context = inputs["text_encoder_output"]["context"]
         else:
             context = inputs["text_encoder_output"]["context_null"]
         clip_fea = inputs["image_encoder_output"]["clip_encoder_out"]
-        ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(self.scheduler.latents.dtype)
+        ref_image_encoder = inputs["image_encoder_output"]["vae_encoder_out"].to(latents.dtype)

         num_channels, _, height, width = x.shape
         ref_num_channels, ref_num_frames, _, _ = ref_image_encoder.shape
@@ -53,15 +54,15 @@ class WanAudioPreInfer(WanPreInfer):
         if ref_num_channels != num_channels:
             zero_padding = torch.zeros(
                 (num_channels - ref_num_channels, ref_num_frames, height, width),
-                dtype=self.scheduler.latents.dtype,
+                dtype=latents.dtype,
-                device=self.scheduler.latents.device,
+                device=latents.device,
             )
             ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=0)
         y = ref_image_encoder

         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
+        grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
@@ -70,8 +71,8 @@ class WanAudioPreInfer(WanPreInfer):
         x = torch.cat([x, y], dim=1).squeeze(0)
         ####for r2v # zero temporl component corresponding to ref embeddings
-        self.freqs[grid_sizes[0][0] :, : self.rope_t_dim] = 0
+        self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
-        grid_sizes[:, 0] += 1
+        grid_sizes_t += 1
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.sensitive_layer_dtype != self.infer_dtype:
@@ -85,15 +86,14 @@ class WanAudioPreInfer(WanPreInfer):
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

         # text embeddings
-        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
         if self.sensitive_layer_dtype != self.infer_dtype:
-            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
+            out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
         else:
-            out = weights.text_embedding_0.apply(stacked.squeeze(0))
+            out = weights.text_embedding_0.apply(context.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)

         if self.clean_cuda_cache:
-            del out, stacked
+            del out
             torch.cuda.empty_cache()

         if self.task == "i2v" and self.config.get("use_image_encoder", True):
@@ -114,7 +114,7 @@ class WanAudioPreInfer(WanPreInfer):
             del context_clip
             torch.cuda.empty_cache()

-        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
+        grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py

@@ -46,7 +46,6 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
         self.crossattn_cache = crossattn_cache

-    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end):
         return self.infer_func(weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, kv_start, kv_end)
@@ -127,7 +126,6 @@ class WanTransformerInferCausVid(WanOffloadTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
-                mask_map=self.mask_map,
             )
         else:
             # TODO: Implement parallel attention for causvid inference
lightx2v/models/networks/wan/infer/post_infer.py

@@ -14,6 +14,7 @@ class WanPostInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.no_grad()
     def infer(self, x, pre_infer_out):
         x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple)
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -23,7 +23,6 @@ class WanPreInfer:
         ).cuda()
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.text_len = config["text_len"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
         self.cfg_scale = config.get("cfg_scale", 4.0)
         self.infer_dtype = GET_DTYPE()
@@ -32,6 +31,7 @@ class WanPreInfer:
     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.no_grad()
     def infer(self, weights, inputs, kv_start=0, kv_end=0):
         x = self.scheduler.latents
         t = self.scheduler.timestep_input
@@ -61,7 +61,7 @@ class WanPreInfer:
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
+        grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
         seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
@@ -84,15 +84,14 @@ class WanPreInfer:
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

         # text embeddings
-        stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
         if self.sensitive_layer_dtype != self.infer_dtype:
-            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
+            out = weights.text_embedding_0.apply(context.squeeze(0).to(self.sensitive_layer_dtype))
         else:
-            out = weights.text_embedding_0.apply(stacked.squeeze(0))
+            out = weights.text_embedding_0.apply(context.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)

         if self.clean_cuda_cache:
-            del out, stacked
+            del out
             torch.cuda.empty_cache()

         if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
@@ -117,7 +116,7 @@ class WanPreInfer:
             del context_clip
             torch.cuda.empty_cache()

-        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
+        grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
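Both pre-infer paths stop packing the post-patch-embedding grid into an int32 tensor that is later read back with .item(); they unpack x.shape[2:] into plain Python ints and only build the tensor once at the end for GridOutput. A small hedged contrast of the two styles (illustrative, not the project's code):

import torch


# Illustrative contrast: why grid sizes are kept as plain Python ints instead of
# a device tensor read back with .item().
def grid_from_tensor(x):
    grid = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
    # .item() copies device -> host and yields data-dependent scalars, which
    # tends to force graph breaks under torch.compile.
    return grid[0][0].item(), grid[0][1].item(), grid[0][2].item()


def grid_from_shape(x):
    # Shape entries are already host-side ints: no sync, compile-friendly.
    t, h, w = x.shape[2:]
    return t, h, w


if __name__ == "__main__":
    x = torch.randn(1, 16, 4, 8, 8)
    assert grid_from_tensor(x) == grid_from_shape(x)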
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -26,7 +26,6 @@ class WanTransformerInfer(BaseTransformerInfer):
         else:
             self.apply_rotary_emb_func = apply_rotary_emb
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
-        self.mask_map = None
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -49,6 +48,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
         return freqs_i

+    @torch.no_grad()
     def infer(self, weights, pre_infer_out):
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)
@@ -186,7 +186,6 @@ class WanTransformerInfer(BaseTransformerInfer):
                 max_seqlen_q=q.size(0),
                 max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
-                mask_map=self.mask_map,
             )
         y = phase.self_attn_o.apply(attn_out)
lightx2v/models/networks/wan/model.py

@@ -7,7 +7,6 @@ import torch.nn.functional as F
 from loguru import logger
 from safetensors import safe_open

-from lightx2v.common.ops.attn import MaskMap
 from lightx2v.models.networks.wan.infer.feature_caching.transformer_infer import (
     WanTransformerInferAdaCaching,
     WanTransformerInferCustomCaching,
@@ -30,6 +29,7 @@ from lightx2v.models.networks.wan.weights.pre_weights import WanPreWeights
 from lightx2v.models.networks.wan.weights.transformer_weights import (
     WanTransformerWeights,
 )
+from lightx2v.utils.custom_compiler import CompiledMethodsMixin, compiled_method
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
@@ -39,11 +39,12 @@ except ImportError:
     gguf = None


-class WanModel:
+class WanModel(CompiledMethodsMixin):
     pre_weight_class = WanPreWeights
     transformer_weight_class = WanTransformerWeights

     def __init__(self, model_path, config, device):
+        super().__init__()
         self.model_path = model_path
         self.config = config
         self.cpu_offload = self.config.get("cpu_offload", False)
@@ -340,11 +341,6 @@ class WanModel:
             self.pre_weight.to_cuda()
             self.transformer_weights.non_block_weights_to_cuda()

-        if self.transformer_infer.mask_map is None:
-            _, c, h, w = self.scheduler.latents.shape
-            video_token_num = c * (h // 2) * (w // 2)
-            self.transformer_infer.mask_map = MaskMap(video_token_num, c)
         if self.config["enable_cfg"]:
             if self.config["cfg_parallel"]:
                 # ==================== CFG Parallel Processing ====================
@@ -378,7 +374,7 @@ class WanModel:
         self.pre_weight.to_cpu()
         self.transformer_weights.non_block_weights_to_cpu()

-    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
+    @compiled_method()
     @torch.no_grad()
     def _infer_cond_uncond(self, inputs, infer_condition=True):
         self.scheduler.infer_condition = infer_condition
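This is the core of the refactor: per-method @torch.compile decorators gated on CHECK_ENABLE_GRAPH_MODE() are replaced by a CompiledMethodsMixin base class plus @compiled_method(), which lets a model switch compilation for a named method on and off at runtime and keep one compiled graph per graph_name (see enable_compile_mode / select_graph usage in audio_model.py above). The actual lightx2v/utils/custom_compiler.py is not part of this diff; the sketch below is only a rough illustration of the pattern, with all behavior inferred rather than taken from the repository:

import functools

import torch


# Rough, assumption-laden sketch of the mixin pattern; the real implementation
# lives in lightx2v/utils/custom_compiler.py and will differ in detail.
def compiled_method():
    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(self, *args, graph_name=None, **kwargs):
            graphs = self._compiled_graphs.setdefault(fn.__name__, {})
            if fn.__name__ in self._compile_enabled:
                # Compile mode: capture one graph per graph_name (e.g. per resolution).
                key = graph_name or "default"
                if key not in graphs:
                    graphs[key] = torch.compile(functools.partial(fn, self))
                return graphs[key](*args, **kwargs)
            # Normal mode: run a previously selected graph if one exists, else eager.
            key = self._selected_graph.get(fn.__name__)
            if key in graphs:
                return graphs[key](*args, **kwargs)
            return fn(self, *args, **kwargs)

        return wrapper

    return decorator


class CompiledMethodsMixin:
    def __init__(self):
        self._compile_enabled = set()
        self._compiled_graphs = {}
        self._selected_graph = {}

    def enable_compile_mode(self, name):
        self._compile_enabled.add(name)

    def disable_compile_mode(self, name):
        self._compile_enabled.discard(name)

    def select_graph(self, name, graph_name):
        self._selected_graph[name] = graph_name

    def get_compile_status(self):
        return {name: sorted(graphs) for name, graphs in self._compiled_graphs.items()}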
lightx2v/models/runners/cogvideox/cogvidex_runner.py

@@ -31,8 +31,7 @@ class CogvideoxRunner(DefaultRunner):
         return vae_model, vae_model

     def init_scheduler(self):
-        scheduler = CogvideoxXDPMScheduler(self.config)
-        self.model.set_scheduler(scheduler)
+        self.scheduler = CogvideoxXDPMScheduler(self.config)

     def run_text_encoder(self, text, img):
         text_encoder_output = {}