xuwx1 / LightX2V

Commit 820b4450
Authored Aug 06, 2025 by gushiqiao; committed by GitHub on Aug 06, 2025

Fix fp32-related bug in audio model

Parents: a3d0f2d9, 4389450a
Showing 3 changed files with 40 additions and 9 deletions (+40, -9):

  lightx2v/models/networks/wan/audio_adapter.py                     +1   -1
  lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py  +19  -4
  lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py   +20  -4
lightx2v/models/networks/wan/audio_adapter.py  (file mode 100644 → 100755)
@@ -121,7 +121,7 @@ class PerceiverAttentionCA(nn.Module):
         x = self.norm_kv(x)
         shift, scale, gate = (t_emb + self.shift_scale_gate).chunk(3, dim=1)
         latents = self.norm_q(latents) * (1 + scale) + shift
-        q = self.to_q(latents)
+        q = self.to_q(latents.to(GET_DTYPE()))
         k, v = self.to_kv(x).chunk(2, dim=-1)
         q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
         k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
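The one-line change above guards against the typical fp32/low-precision mismatch: if the norms run in fp32 while `to_q` holds bf16/fp16 weights, the matmul inside the projection fails. A minimal repro of that failure mode, with hypothetical sizes and a plain nn.Linear standing in for `to_q`:

```python
import torch
import torch.nn as nn

# Minimal repro of the failure mode the one-line fix addresses (hypothetical
# sizes; a plain Linear stands in for to_q). Feeding an fp32 tensor to a
# bf16 Linear raises a dtype mismatch in the underlying matmul.
to_q = nn.Linear(64, 64).to(torch.bfloat16)
latents = torch.randn(2, 16, 64)  # fp32, e.g. the output of an fp32 norm

try:
    to_q(latents)
except RuntimeError as err:
    print(err)  # e.g. "mat1 and mat2 must have the same dtype ..."

q = to_q(latents.to(torch.bfloat16))  # cast the input first, as the patch does
print(q.dtype)  # torch.bfloat16
```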
lightx2v/models/networks/wan/infer/audio/post_wan_audio_infer.py
@@ -3,16 +3,21 @@ import math
 import torch
 from lightx2v.models.networks.wan.infer.post_infer import WanPostInfer
+from lightx2v.utils.envs import *


 class WanAudioPostInfer(WanPostInfer):
     def __init__(self, config):
         self.out_dim = config["out_dim"]
         self.patch_size = (1, 2, 2)
+        self.clean_cuda_cache = config.get("clean_cuda_cache", False)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler

+    @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, weights, x, e, grid_sizes, valid_patch_length):
         if e.dim() == 2:
             modulation = weights.head_modulation.tensor  # 1, 2, dim
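The decorator added above compiles `infer` only when graph mode is enabled. CHECK_ENABLE_GRAPH_MODE() is defined in lightx2v.utils.envs; in the sketch below a plain boolean stands in for it (an assumption about its behavior, not the project's implementation):

```python
import torch

# Illustration of the conditional torch.compile guard added above.
# Assumption: CHECK_ENABLE_GRAPH_MODE() reduces to a boolean switch;
# a plain variable stands in for it here.
enable_graph_mode = False

@torch.compile(disable=not enable_graph_mode)
def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # same shape of computation as the post-infer modulation
    return x * (1 + scale) + shift

x = torch.randn(4, 8)
print(modulate(x, torch.zeros(8), torch.ones(8)).shape)  # torch.Size([4, 8])
```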
@@ -22,13 +27,23 @@ class WanAudioPostInfer(WanPostInfer):
             e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
             e = [ei.squeeze(1) for ei in e]

-        norm_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6).type_as(x)
-        out = norm_out * (1 + e[1].squeeze(0)) + e[0].squeeze(0)
-        x = weights.head.apply(out)
-        x = x[:, :valid_patch_length]
+        x = weights.norm.apply(x)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.sensitive_layer_dtype)
+        x.mul_(1 + e[1].squeeze()).add_(e[0].squeeze())
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            x = x.to(self.infer_dtype)
+        x = weights.head.apply(x)
+        x = x[:, :valid_patch_length]
         x = self.unpatchify(x, grid_sizes)
+
+        if self.clean_cuda_cache:
+            del e, grid_sizes
+            torch.cuda.empty_cache()
+
         return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
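The rewritten output path applies the norm weights, optionally upcasts to the sensitive dtype for the shift/scale modulation, then casts back before the head projection. A simplified sketch of that round trip, using plain modules in place of the project's `weights.norm`/`weights.head` wrappers (shapes and names are illustrative):

```python
import torch
import torch.nn as nn

# Simplified sketch of the reworked output path: norm in the inference dtype,
# upcast for the shift/scale modulation when a sensitive dtype is configured,
# downcast again before the head projection.
infer_dtype = torch.bfloat16      # GET_DTYPE() stand-in
sensitive_dtype = torch.float32   # GET_SENSITIVE_DTYPE() stand-in

norm = nn.LayerNorm(64, elementwise_affine=False, eps=1e-6)
head = nn.Linear(64, 32).to(infer_dtype)

x = torch.randn(1, 128, 64, dtype=infer_dtype)
shift = torch.randn(64, dtype=sensitive_dtype)
scale = torch.randn(64, dtype=sensitive_dtype)

x = norm(x)
if sensitive_dtype != infer_dtype:
    x = x.to(sensitive_dtype)        # modulate in fp32
x.mul_(1 + scale).add_(shift)
if sensitive_dtype != infer_dtype:
    x = x.to(infer_dtype)            # back to the compute dtype
out = head(x)
print(out.dtype)  # torch.bfloat16
```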
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
 import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
+from lightx2v.utils.envs import *
 from ..utils import rope_params, sinusoidal_embedding_1d
@@ -23,6 +24,8 @@ class WanAudioPreInfer(WanPreInfer):
         self.dim = config["dim"]
         self.text_len = config["text_len"]
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
+        self.infer_dtype = GET_DTYPE()
+        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

     def infer(self, weights, inputs, positive):
         prev_latents = inputs["previmg_encoder_output"]["prev_latents"].unsqueeze(0)
@@ -65,6 +68,7 @@ class WanAudioPreInfer(WanPreInfer):
         )
         ref_image_encoder = torch.concat([ref_image_encoder, zero_padding], dim=1)
         y = list(torch.unbind(ref_image_encoder, dim=0))  # split the first (batch) dimension into a list
+
         # embeddings
         x = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in x]
         x_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
@@ -74,28 +78,40 @@ class WanAudioPreInfer(WanPreInfer):
         x = torch.cat([torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1) for u in x])
         valid_patch_length = x[0].size(0)

         y = [weights.patch_embedding.apply(u.unsqueeze(0)) for u in y]
-        y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
+        # y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
         y = [u.flatten(2).transpose(1, 2).squeeze(0) for u in y]
         x = [torch.cat([a, b], dim=0) for a, b in zip(x, y)]
         x = torch.stack(x, dim=0)

         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
-        embed = weights.time_embedding_0.apply(embed)
+        # embed = weights.time_embedding_0.apply(embed)
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            embed = weights.time_embedding_0.apply(embed.to(self.sensitive_layer_dtype))
+        else:
+            embed = weights.time_embedding_0.apply(embed)
         embed = torch.nn.functional.silu(embed)
         embed = weights.time_embedding_2.apply(embed)
         embed0 = torch.nn.functional.silu(embed)
         embed0 = weights.time_projection_1.apply(embed0).unflatten(1, (6, self.dim))

         # text embeddings
         stacked = torch.stack([torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))]) for u in context])
-        out = weights.text_embedding_0.apply(stacked.squeeze(0))
+        if self.sensitive_layer_dtype != self.infer_dtype:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0).to(self.sensitive_layer_dtype))
+        else:
+            out = weights.text_embedding_0.apply(stacked.squeeze(0))
         out = torch.nn.functional.gelu(out, approximate="tanh")
         context = weights.text_embedding_2.apply(out)

+        if self.clean_cuda_cache:
+            del out, stacked
+            torch.cuda.empty_cache()
+
         if self.task == "i2v" and self.config.get("use_image_encoder", True):
             context_clip = weights.proj_0.apply(clip_fea)
+            if self.clean_cuda_cache:
+                del clip_fea
+                torch.cuda.empty_cache()
             context_clip = weights.proj_1.apply(context_clip)
             context_clip = torch.nn.functional.gelu(context_clip, approximate="none")
             context_clip = weights.proj_3.apply(context_clip)
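The pre-infer changes follow the same rule for the timestep and text embeddings: run the first embedding layer in the sensitive dtype when it differs from the inference dtype, and optionally free intermediates when `clean_cuda_cache` is set. A hedged sketch with an ordinary nn.Linear standing in for `weights.time_embedding_0` (the layer and flag names mirror the diff; everything else is illustrative):

```python
import torch
import torch.nn as nn

# Sketch of the conditional-cast pattern in the pre-infer changes; a plain
# Linear stands in for weights.time_embedding_0, and the flag mirrors
# config.get("clean_cuda_cache", False). Shapes are illustrative.
infer_dtype = torch.bfloat16      # GET_DTYPE() stand-in
sensitive_dtype = torch.float32   # GET_SENSITIVE_DTYPE() stand-in
clean_cuda_cache = False

time_embedding_0 = nn.Linear(256, 512)  # kept in fp32 (the sensitive dtype)

embed = torch.randn(1, 256, dtype=infer_dtype)  # stands in for sinusoidal_embedding_1d(...)
if sensitive_dtype != infer_dtype:
    embed = time_embedding_0(embed.to(sensitive_dtype))  # sensitive layer runs in fp32
else:
    embed = time_embedding_0(embed)
embed = torch.nn.functional.silu(embed)

if clean_cuda_cache and torch.cuda.is_available():
    torch.cuda.empty_cache()  # memory hygiene, as the patch adds after the projections
print(embed.dtype)  # torch.float32 here; later layers cast back as needed
```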