xuwx1 / LightX2V · Commits

Commit d02b97a7, authored Jun 26, 2025 by wangshankun

replace clip model

Parent: e58dd9fe
Showing 3 changed files, with 87 additions and 7 deletions:

  lightx2v/models/input_encoders/hf/xlm_roberta/model.py  (+67, -0)
  lightx2v/models/networks/wan/audio_adapter.py           (+1, -0)
  lightx2v/models/runners/wan/wan_audio_runner.py         (+19, -7)
lightx2v/models/input_encoders/hf/xlm_roberta/model.py  (+67, -0)

@@ -11,6 +11,9 @@ import torchvision.transforms as T
 from lightx2v.attentions import attention
 from loguru import logger
 from lightx2v.models.input_encoders.hf.q_linear import QuantLinearInt8, QuantLinearFp8
+from einops import rearrange
+from torch import Tensor
+from transformers import CLIPVisionModel

 __all__ = [

@@ -428,3 +431,67 @@ class CLIPModel:
     def to_cpu(self):
         self.model = self.model.cpu()
+
+
+class WanVideoIPHandler:
+    def __init__(self, model_name, repo_or_path, require_grad=False, mode='eval', device='cuda', dtype=torch.float16):
+        # image_processor = CLIPImageProcessor.from_pretrained(
+        #     repo_or_path, subfolder='image_processor')
+        '''The 720P-I2V-diffusers config has
+            "size": {"shortest_edge": 224}
+        and the 480P-I2V-diffusers config has
+            "size": {"height": 224, "width": 224},
+        but official Wan2.1 uses a no-crop resize by default,
+        so CLIPImageProcessor is not used here.
+        '''
+        image_encoder = CLIPVisionModel.from_pretrained(repo_or_path, subfolder='image_encoder', torch_dtype=dtype)
+        logger.info(f'Using image encoder {model_name} from {repo_or_path}')
+        image_encoder.requires_grad_(require_grad)
+        if mode == 'eval':
+            image_encoder.eval()
+        else:
+            image_encoder.train()
+        self.dtype = dtype
+        self.device = device
+        self.image_encoder = image_encoder.to(device=device, dtype=dtype)
+        self.size = (224, 224)
+        # standard CLIP normalization statistics
+        mean = [0.48145466, 0.4578275, 0.40821073]
+        std = [0.26862954, 0.26130258, 0.27577711]
+        self.normalize = T.Normalize(mean=mean, std=std)
+        # self.image_processor = image_processor
+
+    def encode(self, img_tensor: Tensor):
+        if img_tensor.ndim == 5:
+            # B C T H W with a single frame; drop the time axis
+            # img_tensor = img_tensor[:, :, 0]
+            img_tensor = rearrange(img_tensor, "B C 1 H W -> B C H W")
+        # map [-1, 1] inputs to [0, 1] before resizing and normalizing
+        img_tensor = torch.clamp(img_tensor.float() * 0.5 + 0.5, min=0.0, max=1.0).to(self.device)
+        img_tensor = F.interpolate(img_tensor, size=self.size, mode='bicubic', align_corners=False)
+        img_tensor = self.normalize(img_tensor).to(self.dtype)
+        logger.info(f'Image tensor shape after processing: {img_tensor.shape}')
+        image_embeds = self.image_encoder(pixel_values=img_tensor, output_hidden_states=True)
+        logger.info(f'Image embeds shape: {image_embeds.hidden_states[-1].shape}')
+        return image_embeds.hidden_states[-1]
\ No newline at end of file
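
For orientation, a minimal sketch of how the new handler might be driven; the checkpoint path is hypothetical, and the embedding shape is only indicative of a CLIP ViT-H-style encoder:

    import torch
    from lightx2v.models.input_encoders.hf.xlm_roberta.model import WanVideoIPHandler

    # Hypothetical local path to a Wan2.1-I2V Diffusers layout containing
    # an 'image_encoder' subfolder.
    handler = WanVideoIPHandler(
        "CLIPModel",
        repo_or_path="/path/to/Wan2.1-I2V-14B-720P-Diffusers",
        device="cuda",
        dtype=torch.float16,
    )

    # A single reference frame in [-1, 1]; encode() accepts B C 1 H W
    # (the time axis is dropped internally) or B C H W.
    frame = torch.rand(1, 3, 1, 480, 832) * 2 - 1
    embeds = handler.encode(frame)  # last hidden state, e.g. (1, 257, 1280)
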
lightx2v/models/networks/wan/audio_adapter.py  (+1, -0)

@@ -376,6 +376,7 @@ class AudioAdapterPipe:
         self.device = device
         self.generator = generator
         self.audio_encoder_dtype = torch.float16
+        # audio encoder
         self.audio_encoder = AutoModel.from_pretrained(audio_encoder_repo)
         self.audio_encoder.eval()
lightx2v/models/runners/wan/wan_audio_runner.py  (+19, -7)
@@ -11,7 +11,7 @@ from lightx2v.models.schedulers.wan.scheduler import WanScheduler
 from lightx2v.models.networks.wan.model import WanModel
 from lightx2v.utils.profiler import ProfilingContext4Debug, ProfilingContext
 from lightx2v.models.input_encoders.hf.t5.model import T5EncoderModel
-from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel
+from lightx2v.models.input_encoders.hf.xlm_roberta.model import CLIPModel, WanVideoIPHandler
 from lightx2v.models.networks.wan.audio_model import WanAudioModel
 from lightx2v.models.networks.wan.lora_adapter import WanLoraWrapper
 from lightx2v.models.video_encoders.hf.wan.vae import WanVAE
@@ -244,7 +244,8 @@ class WanAudioRunner(WanRunner):
         super().__init__(config)

     def load_audio_models(self):
-        self.audio_encoder = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
+        # audio feature extractor
+        self.audio_preprocess = AutoFeatureExtractor.from_pretrained(self.config["model_path"], subfolder="audio_encoder")
         audio_adaper = AudioAdapter.from_transformer(
             self.model,
             audio_feature_dim=1024,
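
The rename from `self.audio_encoder` to `self.audio_preprocess` separates the runner's Hugging Face feature extractor from the actual encoder network that `AudioAdapterPipe` loads (see the previous file). A hedged sketch of the two roles, with a hypothetical checkpoint path:

    from transformers import AutoFeatureExtractor, AutoModel

    repo = "/path/to/model_path"  # hypothetical; mirrors self.config["model_path"]

    # Preprocessing only: raw waveform -> padded, normalized input_values.
    audio_preprocess = AutoFeatureExtractor.from_pretrained(repo, subfolder="audio_encoder")

    # The encoder network itself, as held by AudioAdapterPipe.audio_encoder.
    audio_encoder = AutoModel.from_pretrained(repo, subfolder="audio_encoder")
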
@@ -265,6 +266,18 @@ class WanAudioRunner(WanRunner):
         base_model = WanAudioModel(self.config.model_path, self.config, self.init_device)
         return base_model

+    def load_image_encoder(self):
+        image_encoder = WanVideoIPHandler(
+            "CLIPModel",
+            repo_or_path="/mnt/aigc/zoemodels/Wan21/Wan2.1-I2V-14B-720P-Diffusers",
+            require_grad=False,
+            mode='eval',
+            device=self.init_device,
+            dtype=torch.float16)
+        return image_encoder
+
     def run_image_encoder(self, config, vae_model):
         ref_img = Image.open(config.image_path)
         ref_img = (np.array(ref_img).astype(np.float32) - 127.5) / 127.5
@@ -276,8 +289,7 @@ class WanAudioRunner(WanRunner):
         cond_frms, tgt_h, tgt_w = adaptive_resize(ref_img)
         config.tgt_h = tgt_h
         config.tgt_w = tgt_w
-        clip_encoder_out = self.image_encoder.visual([cond_frms.squeeze(0)[:, None, :, :]], config).squeeze(0).to(torch.bfloat16)
+        clip_encoder_out = self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)
         cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
         lat_h, lat_w = tgt_h // 8, tgt_w // 8
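
This call-site change is the core of the commit: the old `CLIPModel.visual` API took a list of C 1 H W tensors plus the runner config, while `WanVideoIPHandler.encode` takes the `1 C H W` tensor directly and handles resizing and CLIP normalization internally. A small sketch of the surrounding shape flow, with a hypothetical resolution:

    import torch
    from einops import rearrange

    # Hypothetical post-adaptive_resize reference frame: 1 C H W, values in [-1, 1].
    cond_frms = torch.rand(1, 3, 480, 832) * 2 - 1

    # Old call (for comparison; CLIPModel.visual expected a list plus config):
    #   self.image_encoder.visual([cond_frms.squeeze(0)[:, None, :, :]], config)
    # New call: self.image_encoder.encode(cond_frms).squeeze(0).to(torch.bfloat16)

    # Afterwards the runner restores a time axis for the VAE conditioning path:
    cond_frms = rearrange(cond_frms, "1 C H W -> 1 C 1 H W")
    print(cond_frms.shape)  # torch.Size([1, 3, 1, 480, 832])
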
@@ -393,7 +405,7 @@ class WanAudioRunner(WanRunner):
         if expected_frames < max_num_frames:
             useful_length = audio_array.shape[0]
             audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-            audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+            audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
         elif res_frame_num > 5 and idx == interval_num - 1:  # the last segment may be shorter than 81 frames
             prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
@@ -404,7 +416,7 @@ class WanAudioRunner(WanRunner):
             audio_array = audio_array_ori[audio_start:audio_end]
             useful_length = audio_array.shape[0]
             audio_array = np.concatenate((audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]), axis=0)
-            audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+            audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
         else:  # middle segments have the full 81 frames with pre_latens
             prev_frames = torch.zeros((1, 3, max_num_frames, tgt_h, tgt_w), device=device)
@@ -413,7 +425,7 @@ class WanAudioRunner(WanRunner):
             prev_len = prev_token_length
             audio_start, audio_end = get_audio_range(idx * max_num_frames - idx * prev_frame_length, (idx + 1) * max_num_frames - idx * prev_frame_length, fps=target_fps, audio_sr=audio_sr)
             audio_array = audio_array_ori[audio_start:audio_end]
-            audio_input_feat = self.audio_encoder(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
+            audio_input_feat = self.audio_preprocess(audio_array, sampling_rate=audio_sr, return_tensors="pt").input_values.squeeze(0)
         self.inputs["audio_encoder_output"] = audio_input_feat.to(device)
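
All three branches share the same padding idiom before the extractor runs: each audio segment is zero-padded on the right to a fixed `max_num_audio_length` so every window yields equally sized features. A standalone sketch with made-up lengths (note that `np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]` is simply `np.zeros(max_num_audio_length - useful_length)`):

    import numpy as np

    max_num_audio_length = 16000 * 4          # hypothetical fixed window, 4 s at 16 kHz
    audio_array = np.random.randn(16000 * 3)  # hypothetical 3 s segment
    useful_length = audio_array.shape[0]

    # Right-pad with zeros up to the fixed length, exactly as in the diff above.
    audio_array = np.concatenate(
        (audio_array, np.zeros(max_num_audio_length)[: max_num_audio_length - useful_length]),
        axis=0,
    )
    assert audio_array.shape[0] == max_num_audio_length

    # The padded waveform then goes through the renamed feature extractor:
    #   self.audio_preprocess(audio_array, sampling_rate=audio_sr,
    #                         return_tensors="pt").input_values.squeeze(0)
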