xuwx1 / LightX2V

Commit df68ed3f, authored Jul 24, 2025 by wangshankun
Parent: 048be946

Update CLIP preprocessing
Showing 4 changed files with 11 additions and 77 deletions (+11 -77):
configs/audio_driven/wan_i2v_audio_adaptive_resize.json   +2 -1
lightx2v/models/input_encoders/hf/xlm_roberta/model.py    +5 -55
lightx2v/models/runners/wan/wan_audio_runner.py           +3 -20
lightx2v/models/runners/wan/wan_runner.py                 +1 -1
configs/audio_driven/wan_i2v_audio_adaptive_resize.json

```diff
@@ -14,5 +14,6 @@
     "sample_shift": 5,
     "enable_cfg": false,
     "cpu_offload": false,
-    "adaptive_resize": true
+    "adaptive_resize": true,
+    "use_31_block": false
 }
```
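The new `use_31_block` key lets a config opt out of the 31st-block CLIP output that `CLIPModel.visual` previously hardcoded (see the model.py hunk below). A minimal sketch of how such a flag reaches the encoder, assuming the JSON is loaded into a simple attribute object; the loader here is ours, not the repo's:

```python
import json
from types import SimpleNamespace

# Hypothetical config loading; LightX2V's real plumbing may differ.
with open("configs/audio_driven/wan_i2v_audio_adaptive_resize.json") as f:
    args = SimpleNamespace(**json.load(f))

# visual() falls back to True when the attribute is absent, so configs
# written before this commit keep the old behavior unchanged.
use_31_block = getattr(args, "use_31_block", True)
print(use_31_block)  # False for this config
```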
lightx2v/models/input_encoders/hf/xlm_roberta/model.py

The three removed imports were used only by the `WanVideoIPHandler` class deleted later in this file:

```diff
@@ -11,10 +11,6 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import VllmQuantLinearInt8, VllmQuantLinearFp8, TorchaoQuantLinearInt8, Q8FQuantLinearInt8, Q8FQuantLinearFp8
-from einops import rearrange
-from torch import Tensor
-from transformers import CLIPVisionModel

 __all__ = [
     "XLMRobertaCLIP",
```
```diff
@@ -448,14 +444,16 @@ class CLIPModel:
     def visual(self, videos, args):
         if hasattr(args, "cpu_offload") and args.cpu_offload:
             self.to_cuda()
+        use_31_block = True
+        if hasattr(args, "use_31_block"):
+            use_31_block = args.use_31_block
         # preprocess
         size = (self.model.image_size,) * 2
-        videos = torch.cat([F.interpolate(u.transpose(0, 1), size=size, mode="bicubic", align_corners=False) for u in videos])
+        videos = torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])
         videos = self.transforms.transforms[-1](videos.mul_(0.5).add_(0.5))

         # forward
         with torch.amp.autocast("cuda", dtype=self.dtype):
-            out = self.model.visual(videos, use_31_block=True)
+            out = self.model.visual(videos, use_31_block=use_31_block)
             if hasattr(args, "cpu_offload") and args.cpu_offload:
                 self.to_cpu()
```
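Two behavioral changes land in this hunk: `use_31_block` is now read from the config instead of being hardcoded to `True`, and the per-item `transpose(0, 1)` is gone, so callers must pass tensors already laid out as N C H W. A self-contained sketch of the layout difference (tensor names and sizes are ours):

```python
import torch
import torch.nn.functional as F

# Before this commit each entry of `videos` was C x T x H x W and visual()
# transposed it to T x C x H x W itself; now entries must already be batched
# as N x C x H x W so F.interpolate can consume them directly.
u_old = torch.rand(3, 1, 480, 832)        # C T H W (old expected layout)
u_new = u_old.transpose(0, 1)             # shape (1, 3, 480, 832), i.e. N C H W

size = (224, 224)                         # (self.model.image_size,) * 2
resized = F.interpolate(u_new, size=size, mode="bicubic", align_corners=False)
print(resized.shape)                      # torch.Size([1, 3, 224, 224])
```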
```diff
@@ -466,51 +464,3 @@ class CLIPModel:
     def to_cpu(self):
         self.model = self.model.cpu()
-
-
-class WanVideoIPHandler:
-    def __init__(self, model_name, repo_or_path, require_grad=False, mode="eval", device="cuda", dtype=torch.float16):
-        # image_processor = CLIPImageProcessor.from_pretrained(
-        #     repo_or_path, subfolder='image_processor')
-        """720P-I2V-diffusers config is
-            "size": {
-              "shortest_edge": 224
-            }
-        and 480P-I2V-diffusers config is
-            "size": {
-              "height": 224,
-              "width": 224
-            }
-        but Wan2.1 official use no_crop resize by default
-        so I don't use CLIPImageProcessor
-        """
-        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, torch_dtype=dtype)
-        logger.info(f"Using image encoder {model_name} from {repo_or_path}")
-        image_encoder.requires_grad_(require_grad)
-        if mode == "eval":
-            image_encoder.eval()
-        else:
-            image_encoder.train()
-        self.dtype = dtype
-        self.device = device
-        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
-        self.size = (224, 224)
-        mean = [0.48145466, 0.4578275, 0.40821073]
-        std = [0.26862954, 0.26130258, 0.27577711]
-        self.normalize = T.Normalize(mean=mean, std=std)
-        # self.image_processor = image_processor
-
-    def encode(self, img_tensor: Tensor):
-        if img_tensor.ndim == 5:
-            # B C T H W
-            # img_tensor = img_tensor[:, :, 0]
-            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
-        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
-        img_tensor = F.interpolate(img_tensor, size=self.size, mode="bicubic", align_corners=False)
-        img_tensor = self.normalize(img_tensor).to(self.dtype)
-        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
-        return image_embeds.hidden_states[-1]
```
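The deleted `WanVideoIPHandler` duplicated preprocessing that `CLIPModel.visual` already performs: both map [-1, 1] pixels back to [0, 1], bicubic-resize to 224x224, and apply the same CLIP normalization statistics. A quick self-contained check of that equivalence (our sketch, not repo code):

```python
import torch
import torchvision.transforms as T

# Both paths normalize identically for in-range inputs: encode() used
# clamp(x * 0.5 + 0.5) followed by T.Normalize, while visual() uses
# mul_(0.5).add_(0.5) followed by the Normalize stored in self.transforms.
mean = [0.48145466, 0.4578275, 0.40821073]
std = [0.26862954, 0.26130258, 0.27577711]
normalize = T.Normalize(mean=mean, std=std)

img = torch.rand(1, 3, 224, 224) * 2 - 1    # pixels in [-1, 1]

old_path = normalize(torch.clamp(img * 0.5 + 0.5, min=0.0, max=1.0))
new_path = normalize(img.clone().mul_(0.5).add_(0.5))

print(torch.allclose(old_path, new_path))   # True: clamp is a no-op in range
```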
lightx2v/models/runners/wan/wan_audio_runner.py

```diff
@@ -10,29 +10,18 @@ from dataclasses import dataclass
 from lightx2v.utils.registry_factory import RUNNER_REGISTER
 from lightx2v.models.runners.wan.wan_runner import WanRunner
 from lightx2v.models.runners.default_runner import DefaultRunner
 from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
 from lightx2v.models.networks.wan.audio_adapter import AudioAdapter, AudioAdapterPipe, rank0_load_state_dict_from_path
 from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
 from lightx2v.models.schedulers.wan.step_distill.scheduler import WanStepDistillScheduler
-from lightx2v.models.schedulers.wan.audio.scheduler import EulerSchedulerTimestepFix, ConsistencyModelScheduler
+from lightx2v.models.schedulers.wan.audio.scheduler import ConsistencyModelScheduler
 from loguru import logger
 import torch.distributed as dist
 from einops import rearrange
 import torchaudio as ta
 from transformers import AutoFeatureExtractor
 from torchvision.datasets.folder import IMG_EXTENSIONS
 from torchvision.transforms import InterpolationMode
 from torchvision.transforms.functional import resize
```
```diff
@@ -618,12 +607,6 @@ class WanAudioRunner(WanRunner):
         return base_model

-    def load_image_encoder(self):
-        """Load image encoder"""
-        clip_model_dir = self.config["model_path"] + "/image_encoder"
-        image_encoder = WanVideoIPHandler("CLIPModel", repo_or_path=clip_model_dir, require_grad=False, mode="eval", device=self.init_device, dtype=torch.float16)
-        return image_encoder
-
     def run_image_encoder(self, config, vae_model):
         """Run image encoder"""
```
```diff
@@ -638,7 +621,7 @@ class WanAudioRunner(WanRunner):
             cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
             config.tgt_h = tgt_h
             config.tgt_w = tgt_w
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
             lat_h, lat_w = tgt_h // 8, tgt_w // 8
```
```diff
@@ -662,7 +645,7 @@ class WanAudioRunner(WanRunner):
             # Resize image to target size
             cond_frms = torch.nn.functional.interpolate(ref_img, size=(config.tgt_h, config.tgt_w), mode="bicubic")
-            clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
+            clip_encoder_out = self.image_encoder.visual([cond_frms], self.config).squeeze(0).to(torch.bfloat16)
             # Prepare for VAE encoding
             cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
```
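In both the adaptive-resize and fixed-size branches the call switches from `WanVideoIPHandler.encode(cond_frms)` to `CLIPModel.visual([cond_frms], self.config)`: a list of already-batched 1 C H W tensors plus the run config, which now also carries `cpu_offload` and `use_31_block`. A condensed, runnable rendering of what `visual` does to that list (assuming `image_size` is 224):

```python
import torch
import torch.nn.functional as F

def preprocess(videos, image_size=224):
    # Mirrors the resize step in CLIPModel.visual after this commit:
    # each list entry is an N x C x H x W tensor, resized and concatenated.
    size = (image_size,) * 2
    return torch.cat([F.interpolate(u, size=size, mode="bicubic", align_corners=False) for u in videos])

cond_frms = torch.rand(1, 3, 480, 832)   # 1 C H W reference frame
print(preprocess([cond_frms]).shape)     # torch.Size([1, 3, 224, 224])
```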
lightx2v/models/runners/wan/wan_runner.py

```diff
@@ -197,7 +197,7 @@ class WanRunner(DefaultRunner):
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             self.image_encoder = self.load_image_encoder()
         img = TF.to_tensor(img).sub_(0.5).div_(0.5).cuda()
-        clip_encoder_out = self.image_encoder.visual([img[:, None, :, :]], self.config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.visual([img[None, :, :, :]], self.config).squeeze(0).to(torch.bfloat16)
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
             del self.image_encoder
             torch.cuda.empty_cache()
```
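The one-line change here tracks the transpose removal in `visual()`: `TF.to_tensor` produces a C H W image, so the batch axis must now be inserted at dim 0 rather than dim 1. A self-contained illustration:

```python
import torch

# Old code inserted the extra axis at dim 1 (C x 1 x H x W) because
# visual() transposed dims 0 and 1 itself; new code inserts it at dim 0
# (1 x C x H x W) to match the N C H W layout F.interpolate expects.
img = torch.rand(3, 480, 832)              # C H W, as from TF.to_tensor
print(img[:, None, :, :].shape)            # torch.Size([3, 1, 480, 832])  old
print(img[None, :, :, :].shape)            # torch.Size([1, 3, 480, 832])  new
```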