Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
c550409c
Commit
c550409c
authored
Aug 15, 2025
by
wangshankun
Browse files
update ref timestep expand
parent
0b755a97
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
18 additions
and
1 deletion
+18
-1
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
...2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
+18
-1
No files found.
lightx2v/models/networks/wan/infer/audio/pre_wan_audio_infer.py
View file @
c550409c
...
@@ -104,17 +104,34 @@ class WanAudioPreInfer(WanPreInfer):
...
@@ -104,17 +104,34 @@ class WanAudioPreInfer(WanPreInfer):
y
=
[
weights
.
patch_embedding
.
apply
(
u
.
unsqueeze
(
0
))
for
u
in
y
]
y
=
[
weights
.
patch_embedding
.
apply
(
u
.
unsqueeze
(
0
))
for
u
in
y
]
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
# y_grid_sizes = torch.stack([torch.tensor(u.shape[2:], dtype=torch.long) for u in y])
y
=
[
u
.
flatten
(
2
).
transpose
(
1
,
2
).
squeeze
(
0
)
for
u
in
y
]
y
=
[
u
.
flatten
(
2
).
transpose
(
1
,
2
).
squeeze
(
0
)
for
u
in
y
]
ref_seq_lens
=
torch
.
tensor
([
u
.
size
(
0
)
for
u
in
y
],
dtype
=
torch
.
long
)
x
=
[
torch
.
cat
([
a
,
b
],
dim
=
0
)
for
a
,
b
in
zip
(
x
,
y
)]
x
=
[
torch
.
cat
([
a
,
b
],
dim
=
0
)
for
a
,
b
in
zip
(
x
,
y
)]
x
=
torch
.
stack
(
x
,
dim
=
0
)
x
=
torch
.
stack
(
x
,
dim
=
0
)
seq_len
=
x
[
0
].
size
(
0
)
if
self
.
config
.
model_cls
==
"wan2.2_audio"
:
bt
=
t
.
size
(
0
)
ref_seq_len
=
ref_seq_lens
[
0
].
item
()
t
=
torch
.
cat
(
[
t
,
torch
.
zeros
(
(
1
,
ref_seq_len
),
dtype
=
t
.
dtype
,
device
=
t
.
device
,
),
],
dim
=
1
,
)
embed
=
sinusoidal_embedding_1d
(
self
.
freq_dim
,
t
.
flatten
())
embed
=
sinusoidal_embedding_1d
(
self
.
freq_dim
,
t
.
flatten
())
# embed = weights.time_embedding_0.apply(embed)
if
self
.
sensitive_layer_dtype
!=
self
.
infer_dtype
:
if
self
.
sensitive_layer_dtype
!=
self
.
infer_dtype
:
embed
=
weights
.
time_embedding_0
.
apply
(
embed
.
to
(
self
.
sensitive_layer_dtype
))
embed
=
weights
.
time_embedding_0
.
apply
(
embed
.
to
(
self
.
sensitive_layer_dtype
))
else
:
else
:
embed
=
weights
.
time_embedding_0
.
apply
(
embed
)
embed
=
weights
.
time_embedding_0
.
apply
(
embed
)
embed
=
torch
.
nn
.
functional
.
silu
(
embed
)
embed
=
torch
.
nn
.
functional
.
silu
(
embed
)
embed
=
weights
.
time_embedding_2
.
apply
(
embed
)
embed
=
weights
.
time_embedding_2
.
apply
(
embed
)
embed0
=
torch
.
nn
.
functional
.
silu
(
embed
)
embed0
=
torch
.
nn
.
functional
.
silu
(
embed
)
embed0
=
weights
.
time_projection_1
.
apply
(
embed0
).
unflatten
(
1
,
(
6
,
self
.
dim
))
embed0
=
weights
.
time_projection_1
.
apply
(
embed0
).
unflatten
(
1
,
(
6
,
self
.
dim
))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment