Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
984cd6c9
Commit
984cd6c9
authored
Aug 01, 2025
by
gushiqiao
Committed by
GitHub
Aug 01, 2025
Browse files
Fix audio offload bug
Fix audio offload bug
parents
348822d9
77bef6e8
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
49 additions
and
29 deletions
+49
-29
configs/offload/block/wan_i2v_audio_block.json
configs/offload/block/wan_i2v_audio_block.json
+22
-0
lightx2v/models/networks/wan/audio_adapter.py
lightx2v/models/networks/wan/audio_adapter.py
+10
-11
lightx2v/models/runners/wan/wan_audio_runner.py
lightx2v/models/runners/wan/wan_audio_runner.py
+17
-18
No files found.
configs/offload/block/wan_i2v_audio_block.json
0 → 100644
View file @
984cd6c9
{
"infer_steps"
:
4
,
"target_fps"
:
16
,
"video_duration"
:
16
,
"audio_sr"
:
16000
,
"target_video_length"
:
81
,
"target_height"
:
720
,
"target_width"
:
1280
,
"self_attn_1_type"
:
"flash_attn3"
,
"cross_attn_1_type"
:
"flash_attn3"
,
"cross_attn_2_type"
:
"flash_attn3"
,
"seed"
:
42
,
"sample_guide_scale"
:
1
,
"sample_shift"
:
5
,
"enable_cfg"
:
false
,
"cpu_offload"
:
true
,
"offload_granularity"
:
"block"
,
"t5_cpu_offload"
:
true
,
"offload_ratio_val"
:
1
,
"t5_offload_granularity"
:
"block"
,
"use_tiling_vae"
:
true
}
lightx2v/models/networks/wan/audio_adapter.py
View file @
984cd6c9
...
@@ -2,6 +2,8 @@ try:
...
@@ -2,6 +2,8 @@ try:
import
flash_attn
import
flash_attn
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
flash_attn
=
None
flash_attn
=
None
import
os
import
safetensors
import
math
import
math
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -9,11 +11,6 @@ import torch.nn.functional as F
...
@@ -9,11 +11,6 @@ import torch.nn.functional as F
from
diffusers.models.embeddings
import
TimestepEmbedding
,
Timesteps
from
diffusers.models.embeddings
import
TimestepEmbedding
,
Timesteps
from
einops
import
rearrange
from
einops
import
rearrange
from
transformers
import
AutoModel
from
transformers
import
AutoModel
from
loguru
import
logger
import
os
import
safetensors
from
typing
import
List
,
Optional
,
Tuple
,
Union
def
load_safetensors
(
in_path
:
str
):
def
load_safetensors
(
in_path
:
str
):
...
@@ -370,13 +367,12 @@ class AudioAdapter(nn.Module):
...
@@ -370,13 +367,12 @@ class AudioAdapter(nn.Module):
class
AudioAdapterPipe
:
class
AudioAdapterPipe
:
def
__init__
(
def
__init__
(
self
,
audio_adapter
:
AudioAdapter
,
audio_encoder_repo
:
str
=
"microsoft/wavlm-base-plus"
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
generator
=
None
,
tgt_fps
:
int
=
15
,
weight
:
float
=
1.0
self
,
audio_adapter
:
AudioAdapter
,
audio_encoder_repo
:
str
=
"microsoft/wavlm-base-plus"
,
dtype
=
torch
.
float32
,
device
=
"cuda"
,
tgt_fps
:
int
=
15
,
weight
:
float
=
1.0
,
cpu_offload
:
bool
=
False
)
->
None
:
)
->
None
:
self
.
audio_adapter
=
audio_adapter
self
.
audio_adapter
=
audio_adapter
self
.
dtype
=
dtype
self
.
dtype
=
dtype
self
.
device
=
device
self
.
generator
=
generator
self
.
audio_encoder_dtype
=
torch
.
float16
self
.
audio_encoder_dtype
=
torch
.
float16
self
.
cpu_offload
=
cpu_offload
##音频编码器
##音频编码器
self
.
audio_encoder
=
AutoModel
.
from_pretrained
(
audio_encoder_repo
)
self
.
audio_encoder
=
AutoModel
.
from_pretrained
(
audio_encoder_repo
)
...
@@ -403,11 +399,14 @@ class AudioAdapterPipe:
...
@@ -403,11 +399,14 @@ class AudioAdapterPipe:
audio_length
=
int
(
50
/
self
.
tgt_fps
*
video_frame
)
audio_length
=
int
(
50
/
self
.
tgt_fps
*
video_frame
)
with
torch
.
no_grad
():
with
torch
.
no_grad
():
audio_input_feat
=
audio_input_feat
.
to
(
self
.
device
,
self
.
audio_encoder_dtype
)
try
:
try
:
audio_feat
=
self
.
audio_encoder
(
audio_input_feat
,
return_dict
=
True
).
last_hidden_state
if
self
.
cpu_offload
:
self
.
audio_encoder
=
self
.
audio_encoder
.
to
(
"cuda"
)
audio_feat
=
self
.
audio_encoder
(
audio_input_feat
.
to
(
self
.
audio_encoder_dtype
),
return_dict
=
True
).
last_hidden_state
if
self
.
cpu_offload
:
self
.
audio_encoder
=
self
.
audio_encoder
.
to
(
"cpu"
)
except
Exception
as
err
:
except
Exception
as
err
:
audio_feat
=
torch
.
rand
(
1
,
audio_length
,
self
.
audio_feature_dim
).
to
(
self
.
device
)
audio_feat
=
torch
.
rand
(
1
,
audio_length
,
self
.
audio_feature_dim
).
to
(
"cuda"
)
print
(
err
)
print
(
err
)
audio_feat
=
audio_feat
.
to
(
self
.
dtype
)
audio_feat
=
audio_feat
.
to
(
self
.
dtype
)
if
dropout_cond
is
not
None
:
if
dropout_cond
is
not
None
:
...
...
lightx2v/models/runners/wan/wan_audio_runner.py
View file @
984cd6c9
...
@@ -2,32 +2,27 @@ import os
...
@@ -2,32 +2,27 @@ import os
import
gc
import
gc
import
numpy
as
np
import
numpy
as
np
import
torch
import
torch
import
torchvision.transforms.functional
as
TF
import
subprocess
import
torchaudio
as
ta
from
PIL
import
Image
from
PIL
import
Image
from
contextlib
import
contextmanager
from
contextlib
import
contextmanager
from
typing
import
Optional
,
Tuple
,
Union
,
List
,
Dict
,
Any
from
typing
import
Optional
,
Tuple
,
List
,
Dict
,
Any
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
loguru
import
logger
from
einops
import
rearrange
from
transformers
import
AutoFeatureExtractor
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms.functional
import
resize
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.utils.registry_factory
import
RUNNER_REGISTER
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
from
lightx2v.models.runners.wan.wan_runner
import
WanRunner
,
MultiModelStruct
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
,
ProfilingContext
from
lightx2v.utils.profiler
import
ProfilingContext4Debug
,
ProfilingContext
from
lightx2v.models.networks.wan.audio_model
import
WanAudioModel
,
Wan22MoeAudioModel
from
lightx2v.models.networks.wan.audio_model
import
WanAudioModel
,
Wan22MoeAudioModel
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.networks.wan.lora_adapter
import
WanLoraWrapper
from
lightx2v.models.networks.wan.audio_adapter
import
AudioAdapter
,
AudioAdapterPipe
,
rank0_load_state_dict_from_path
from
lightx2v.models.networks.wan.audio_adapter
import
AudioAdapter
,
AudioAdapterPipe
,
rank0_load_state_dict_from_path
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
from
lightx2v.utils.utils
import
save_to_video
,
vae_to_comfyui_image
from
lightx2v.models.schedulers.wan.audio.scheduler
import
ConsistencyModelScheduler
from
lightx2v.models.schedulers.wan.audio.scheduler
import
ConsistencyModelScheduler
from
.wan_runner
import
MultiModelStruct
from
loguru
import
logger
from
einops
import
rearrange
import
torchaudio
as
ta
from
transformers
import
AutoFeatureExtractor
from
torchvision.transforms
import
InterpolationMode
from
torchvision.transforms.functional
import
resize
import
subprocess
import
warnings
@
contextmanager
@
contextmanager
...
@@ -424,9 +419,13 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -424,9 +419,13 @@ class WanAudioRunner(WanRunner): # type:ignore
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
)
audio_adapter
=
rank0_load_state_dict_from_path
(
audio_adapter
,
audio_adapter_path
,
strict
=
False
)
# Audio encoder
# Audio encoder
device
=
torch
.
device
(
"cuda"
)
cpu_offload
=
self
.
config
.
get
(
"cpu_offload"
,
False
)
if
cpu_offload
:
device
=
torch
.
device
(
"cpu"
)
else
:
device
=
torch
.
device
(
"cuda"
)
audio_encoder_repo
=
self
.
config
[
"model_path"
]
+
"/audio_encoder"
audio_encoder_repo
=
self
.
config
[
"model_path"
]
+
"/audio_encoder"
self
.
_audio_adapter_pipe
=
AudioAdapterPipe
(
audio_adapter
,
audio_encoder_repo
=
audio_encoder_repo
,
dtype
=
torch
.
bfloat16
,
device
=
device
,
generator
=
torch
.
Generator
(
device
),
weight
=
1.0
)
self
.
_audio_adapter_pipe
=
AudioAdapterPipe
(
audio_adapter
,
audio_encoder_repo
=
audio_encoder_repo
,
dtype
=
torch
.
bfloat16
,
device
=
device
,
weight
=
1.0
,
cpu_offload
=
cpu_offload
)
return
self
.
_audio_adapter_pipe
return
self
.
_audio_adapter_pipe
...
@@ -622,7 +621,7 @@ class WanAudioRunner(WanRunner): # type:ignore
...
@@ -622,7 +621,7 @@ class WanAudioRunner(WanRunner): # type:ignore
ref_img
=
Image
.
open
(
config
.
image_path
)
ref_img
=
Image
.
open
(
config
.
image_path
)
ref_img
=
(
np
.
array
(
ref_img
).
astype
(
np
.
float32
)
-
127.5
)
/
127.5
ref_img
=
(
np
.
array
(
ref_img
).
astype
(
np
.
float32
)
-
127.5
)
/
127.5
ref_img
=
torch
.
from_numpy
(
ref_img
).
to
(
vae_model
.
device
)
ref_img
=
torch
.
from_numpy
(
ref_img
).
cuda
(
)
ref_img
=
rearrange
(
ref_img
,
"H W C -> 1 C H W"
)
ref_img
=
rearrange
(
ref_img
,
"H W C -> 1 C H W"
)
ref_img
=
ref_img
[:,
:
3
]
ref_img
=
ref_img
[:,
:
3
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment