Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
c33c4896
Commit
c33c4896
authored
Sep 04, 2025
by
gushiqiao
Committed by
GitHub
Sep 04, 2025
Browse files
Update utils.py (#286)
parent
1767ff4b
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
3 additions
and
3 deletions
+3
-3
lightx2v/models/networks/wan/infer/utils.py
lightx2v/models/networks/wan/infer/utils.py
+3
-3
No files found.
lightx2v/models/networks/wan/infer/utils.py
View file @
c33c4896
...
@@ -69,7 +69,7 @@ def apply_rotary_emb(x, freqs_i):
...
@@ -69,7 +69,7 @@ def apply_rotary_emb(x, freqs_i):
n
=
x
.
size
(
1
)
n
=
x
.
size
(
1
)
seq_len
=
freqs_i
.
size
(
0
)
seq_len
=
freqs_i
.
size
(
0
)
x_i
=
torch
.
view_as_complex
(
x
[:
seq_len
].
to
(
torch
.
float
64
).
reshape
(
seq_len
,
n
,
-
1
,
2
))
x_i
=
torch
.
view_as_complex
(
x
[:
seq_len
].
to
(
torch
.
float
32
).
reshape
(
seq_len
,
n
,
-
1
,
2
))
# Apply rotary embedding
# Apply rotary embedding
x_i
=
torch
.
view_as_real
(
x_i
*
freqs_i
).
flatten
(
2
)
x_i
=
torch
.
view_as_real
(
x_i
*
freqs_i
).
flatten
(
2
)
x_i
=
torch
.
cat
([
x_i
,
x
[
seq_len
:]])
x_i
=
torch
.
cat
([
x_i
,
x
[
seq_len
:]])
...
@@ -113,7 +113,7 @@ def rope_params(max_seq_len, dim, theta=10000):
...
@@ -113,7 +113,7 @@ def rope_params(max_seq_len, dim, theta=10000):
assert
dim
%
2
==
0
assert
dim
%
2
==
0
freqs
=
torch
.
outer
(
freqs
=
torch
.
outer
(
torch
.
arange
(
max_seq_len
),
torch
.
arange
(
max_seq_len
),
1.0
/
torch
.
pow
(
theta
,
torch
.
arange
(
0
,
dim
,
2
).
to
(
torch
.
float
64
).
div
(
dim
)),
1.0
/
torch
.
pow
(
theta
,
torch
.
arange
(
0
,
dim
,
2
).
to
(
torch
.
float
32
).
div
(
dim
)),
)
)
freqs
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
freqs
=
torch
.
polar
(
torch
.
ones_like
(
freqs
),
freqs
)
return
freqs
return
freqs
...
@@ -123,7 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
...
@@ -123,7 +123,7 @@ def sinusoidal_embedding_1d(dim, position):
# preprocess
# preprocess
assert
dim
%
2
==
0
assert
dim
%
2
==
0
half
=
dim
//
2
half
=
dim
//
2
position
=
position
.
type
(
torch
.
float
64
)
position
=
position
.
type
(
torch
.
float
32
)
# calculation
# calculation
sinusoid
=
torch
.
outer
(
position
,
torch
.
pow
(
10000
,
-
torch
.
arange
(
half
).
to
(
position
).
div
(
half
)))
sinusoid
=
torch
.
outer
(
position
,
torch
.
pow
(
10000
,
-
torch
.
arange
(
half
).
to
(
position
).
div
(
half
)))
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment