xuwx1 / LightX2V · Commits · 039456f2

Commit 039456f2
Authored Jul 24, 2025 by sandy; committed by GitHub on Jul 24, 2025

Merge pull request #161 from ModelTC/feat-audio

Update CLIP preprocessing (更新clip预处理)

Parents: 048be946, 7e6f9418
Showing 4 changed files with 9 additions and 77 deletions (+9 -77):

    configs/audio_driven/wan_i2v_audio_adaptive_resize.json   +2 -1
    lightx2v/models/input_encoders/hf/xlm_roberta/model.py    +3 -55
    lightx2v/models/runners/wan/wan_audio_runner.py           +3 -20
    lightx2v/models/runners/wan/wan_runner.py                  +1 -1
configs/audio_driven/wan_i2v_audio_adaptive_resize.json

@@ -14,5 +14,6 @@
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
-    "adaptive_resize": true
+    "adaptive_resize": true,
+    "use_31_block": false
 }
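
Note: the new "use_31_block" key pairs with the change to CLIPModel.visual below, which now reads the flag from the runner config and defaults to True (the previously hard-coded value), so existing configs keep their old behaviour. Judging by the name, setting it to false asks the CLIP visual tower for its final output rather than the 31st-block features, though the diff itself only shows the flag being threaded through. A minimal sketch (not repo code; the runner's real config object may differ) of how such a flag reaches the encoder call:

    # Illustrative only: load the JSON config and apply the same getattr fallback
    # that the updated CLIPModel.visual uses (missing key -> True, the old behaviour).
    import json
    from types import SimpleNamespace

    with open("configs/audio_driven/wan_i2v_audio_adaptive_resize.json") as f:
        cfg = SimpleNamespace(**json.load(f))

    use_31_block = getattr(cfg, "use_31_block", True)
    print(use_31_block)  # False for this audio-driven config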
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

@@ -11,10 +11,6 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
-from einops import rearrange
-from torch import Tensor
-from transformers import CLIPVisionModel

 __all__ = [
     "XLMRobertaCLIP",

@@ -448,14 +444,14 @@ class CLIPModel:
     def visual(self, videos, args):
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
+        use_31_block = getattr(args, "use_31_block", True)
         # preprocess
         size = (self.model.image_size,) * 2
-        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
+        videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=True)
+            out = self.model.visual(videos, use_31_block=use_31_block)
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cpu()

@@ -466,51 +462,3 @@ class CLIPModel:
     def to_cpu(self):
         self.model = self.model.cpu()
-
-
-class WanVideoIPHandler:
-    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
-        # image_processor = CLIPImageProcessor.from_pretrained(
-        #     repo_or_path, subfolder='image_processor')
-        """720P-I2V-diffusers config is
-        "size": {
-            "shortest_edge": 224
-        }
-        and 480P-I2V-diffusers config is
-        "size": {
-            "height": 224,
-            "width": 224
-        }
-        but Wan2.1 official use no_crop resize by default
-        so I don't use CLIPImageProcessor
-        """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
-        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
-        image_encoder.requires_grad_(require_grad)
-        if mode == "eval":
-            image_encoder.eval()
-        else:
-            image_encoder.train()
-        self.dtype = dtype
-        self.device = device
-        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
-        self.size = (224, 224)
-        mean = [0.48145466, 0.4578275, 0.40821073]
-        std = [0.26862954, 0.26130258, 0.27577711]
-        self.normalize = T.Normalize(mean=mean, std=std)
-        # self.image_processor = image_processor
-
-    def encode(self, img_tensor: Tensor):
-        if img_tensor.ndim == 5:
-            # B C T H W
-            # img_tensor = img_tensor[:, :, 0]
-            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
-        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
-        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
-        img_tensor = self.normalize(img_tensor).to(self.dtype)
-        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
-        return image_embeds.hidden_states[-1]
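
Note: two things change in CLIPModel.visual. First, use_31_block is no longer hard-coded to True but taken from the config (defaulting to True). Second, the preprocessing drops the per-clip transpose: each element of videos is now expected to arrive as a batch-first (N, C, H, W) tensor that F.interpolate can resize directly, instead of the old (C, T, H, W) layout. A shape sketch with made-up sizes (illustrative, not repo code):

    # Why the transpose is no longer needed: callers now pass "1 C H W" tensors.
    import torch
    import torch.nn.functional as F

    image_size = 224                             # assumed CLIP input resolution
    cond_frms = torch.randn(1, 3, 480, 832)      # example "1 C H W" reference frame

    size = (image_size,) * 2
    videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in [cond_frms]])
    print(videos.shape)  # torch.Size([1, 3, 224, 224])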
lightx2v/models/runners/wan/wan_audio_runner.py

@@ -10,29 +10,18 @@ from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
-from lightx2v.models.runners.default_runner import DefaultRunner
-from lightx2v.models.schedulers.wan.scheduler import WanScheduler
-from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
-from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
-from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
-from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
-from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
 from loguru import logger
-import torch.distributed as dist
 from einops import rearrange
 import torchaudio as ta
 from transformers import AutoFeatureExtractor
-from torchvision.datasets.folder import IMG_EXTENSIONS
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize

@@ -618,12 +607,6 @@ class WanAudioRunner(WanRunner):
         return base_model

-    def load_image_encoder(self):
-        """Load image encoder"""
-        clip_model_dir = self.config["model_path"] + "/image_encoder"
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
-        return image_encoder
-
     def run_image_encoder(self, config, vae_model):
         """Run image encoder"""

@@ -638,7 +621,7 @@ class WanAudioRunner(WanRunner):
             cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
             config.tgt_h = tgt_h
             config.tgt_w = tgt_w
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
             lat_h, lat_w = tgt_h // 8, tgt_w // 8

@@ -662,7 +645,7 @@ class WanAudioRunner(WanRunner):
             # Resize image to target size
             cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             # Prepare for VAE encoding
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
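
Note: with WanVideoIPHandler removed, WanAudioRunner no longer overrides load_image_encoder; it inherits the CLIPModel-based loader from WanRunner, and both resize branches now call the same visual([frames], self.config) entry point that WanRunner already uses. A hedged call-pattern sketch (the stub encoder and feature shape are placeholders, not the real CLIP output):

    # Illustrative only: shows the caller-side pattern, not the real encoder.
    import torch

    class StubEncoder:
        def visual(self, videos, args):
            # The real CLIPModel resizes, normalizes, and runs the visual tower;
            # here we just return one placeholder feature map per input tensor.
            return torch.zeros(len(videos), 257, 1280)

    image_encoder = StubEncoder()
    cond_frms = torch.randn(1, 3, 480, 832)                # "1 C H W" reference frame
    config = type("Cfg", (), {"use_31_block": False})()    # attribute-style config stand-in

    clip_encoder_out = image_encoder.visual([cond_frms], config).squeeze(0).to(torch.bfloat16)
    print(clip_encoder_out.shape, clip_encoder_out.dtype)  # torch.Size([257, 1280]) torch.bfloat16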
lightx2v/models/runners/wan/wan_runner.py

@@ -197,7 +197,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
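
Note: the single-line change here matches the new preprocessing contract. TF.to_tensor returns a (C, H, W) tensor; img[None, :, :, :] prepends a batch dimension to give (1, C, H, W), which CLIPModel.visual can now resize directly, whereas the old img[:, None, :, :] inserted a length-1 second axis, (C, 1, H, W), for the former (C, T, H, W) path. A quick shape check with illustrative sizes:

    import torch

    img = torch.randn(3, 480, 832)        # what TF.to_tensor(...) yields: (C, H, W)
    print(img[None, :, :, :].shape)       # torch.Size([1, 3, 480, 832]) -> new layout
    print(img[:, None, :, :].shape)       # torch.Size([3, 1, 480, 832]) -> old layout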