xuwx1 / LightX2V · Commits

Commit 2f874771, authored May 23, 2025 by GoatWu
Commit message: bug fixed
Parent: 429dcc45

Showing 1 changed file with 57 additions and 20 deletions:
lightx2v/models/networks/wan/infer/causvid/transformer_infer.py (+57, -20)

In summary, the change splits WanTransformerInferCausVid.infer_block into three phase methods (_infer_self_attn, _infer_cross_attn, _infer_ffn) dispatched over weights.compute_phases, and simplifies the _calculate_q_k_len call sites: the helper now returns only cu_seqlens_q and cu_seqlens_k, and max_seqlen_q/max_seqlen_kv are passed as direct .size(0) lookups instead of the former lq/lk values.
@@ -90,16 +90,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                 kv_end,
             )
 
         return x
 
-    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
-        if embed0.dim() == 3:
-            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
-            embed0 = embed0.unsqueeze(0)  #
-            embed0 = (modulation + embed0).chunk(6, dim=1)
-            embed0 = [ei.squeeze(1) for ei in embed0]
-        elif embed0.dim() == 2:
-            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+    def _infer_self_attn(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
         norm1_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         norm1_out = (norm1_out * (1 + embed0[1]) + embed0[0]).squeeze(0)
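
Aside (not part of the diff): the modulation split that this hunk relocates into the new infer_block at the bottom of the file turns one tensor into six gate/shift signals. A minimal runnable sketch with illustrative sizes, standing in a zeros tensor for weights.modulation.tensor:

import torch

dim, L = 8, 4  # hypothetical hidden size and token count

# 2-D case (the elif branch): embed0 is (6, dim)
modulation = torch.zeros(1, 6, dim)  # stand-in for weights.modulation.tensor
embed0_2d = torch.randn(6, dim)
chunks = (modulation + embed0_2d).chunk(6, dim=1)
print(len(chunks), chunks[0].shape)  # 6 torch.Size([1, 1, 8])

# 3-D case: embed0 is (6, L, dim); modulation gains a broadcast axis first
embed0_3d = torch.randn(6, L, dim)
chunks = (modulation.unsqueeze(2) + embed0_3d.unsqueeze(0)).chunk(6, dim=1)
chunks = [ei.squeeze(1) for ei in chunks]
print(len(chunks), chunks[0].shape)  # 6 torch.Size([1, 4, 8])

The six entries match the embed0[i] indexing used throughout the diff: 0/1 shift and scale the self-attention LayerNorm, 2 scales its residual, 3/4 shift and scale the FFN input, and 5 scales the FFN residual.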
@@ -120,7 +112,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         self.kv_cache[block_idx]["k"][kv_start:kv_end] = k
         self.kv_cache[block_idx]["v"][kv_start:kv_end] = v
 
-        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q=q, k=self.kv_cache[block_idx]["k"][:kv_end], k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
+        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q=q, k_lens=torch.tensor([kv_end], dtype=torch.int32, device=k.device))
 
         if not self.parallel_attention:
             attn_out = weights.self_attn_1.apply(
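
_calculate_q_k_len itself is outside this diff; a plausible two-value form consistent with the new call sites (an assumption, not code from the repository) would build the cumulative sequence-length tensors that varlen attention kernels expect:

import torch

def _calculate_q_k_len(q, k_lens):
    # Hypothetical reconstruction: one query sequence, caller-supplied key lengths.
    q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
    cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
    cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
    return cu_seqlens_q, cu_seqlens_k

q = torch.randn(3, 2, 4)  # (tokens, heads, head_dim), illustrative
cu_q, cu_k = _calculate_q_k_len(q, k_lens=torch.tensor([5], dtype=torch.int32))
print(cu_q, cu_k)  # tensor([0, 3], dtype=torch.int32) tensor([0, 5], dtype=torch.int32)

Under this reading, the dropped lq/lk return values were redundant, which is why the changed call sites below pass max_seqlen_q/max_seqlen_kv as direct .size(0) lookups.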
@@ -129,8 +121,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                 v=self.kv_cache[block_idx]["v"][:kv_end],
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                max_seqlen_q=q.size(0),
+                max_seqlen_kv=k.size(0),
                 model_cls=self.config["model_cls"],
             )
         else:
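
For context, the cache discipline around this call: a preallocated per-block buffer is written in place and the kernel attends over the filled prefix. A minimal sketch with illustrative shapes and hypothetical names (not the repo's):

import torch

max_len, n_heads, head_dim = 16, 2, 4  # illustrative cache geometry
kv_cache = {
    "k": torch.zeros(max_len, n_heads, head_dim),
    "v": torch.zeros(max_len, n_heads, head_dim),
}

kv_start, kv_end = 0, 5  # this step contributes 5 new tokens
k_new = torch.randn(kv_end - kv_start, n_heads, head_dim)
kv_cache["k"][kv_start:kv_end] = k_new  # in-place write, as in the diff
k_for_attn = kv_cache["k"][:kv_end]     # attention reads the filled prefix
print(k_for_attn.shape)                 # torch.Size([5, 2, 4])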
@@ -141,6 +133,9 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         x = x + y * embed0[2].squeeze(0)
+        return x
+
+    def _infer_cross_attn(self, weights, x, context, block_idx):
         norm3_out = weights.norm3.apply(x)
         if self.task == "i2v":
@@ -159,7 +154,7 @@ class WanTransformerInferCausVid(WanTransformerInfer):
             k = self.crossattn_cache[block_idx]["k"]
             v = self.crossattn_cache[block_idx]["v"]
 
-        cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(q, k, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
+        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device))
 
         attn_out = weights.cross_attn_1.apply(
             q=q,
@@ -167,8 +162,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
             v=v,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_k,
-            max_seqlen_q=lq,
-            max_seqlen_kv=lk,
+            max_seqlen_q=q.size(0),
+            max_seqlen_kv=k.size(0),
             model_cls=self.config["model_cls"],
         )
@@ -176,9 +171,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
             k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
             v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
-            cu_seqlens_q, cu_seqlens_k, lq, lk = self._calculate_q_k_len(
+            cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
                 q,
-                k_img,
                 k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
             )
@@ -188,8 +182,8 @@ class WanTransformerInferCausVid(WanTransformerInfer):
                 v=v_img,
                 cu_seqlens_q=cu_seqlens_q,
                 cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=lq,
-                max_seqlen_kv=lk,
+                max_seqlen_q=q.size(0),
+                max_seqlen_kv=k_img.size(0),
                 model_cls=self.config["model_cls"],
             )
@@ -198,9 +192,52 @@ class WanTransformerInferCausVid(WanTransformerInfer):
         attn_out = weights.cross_attn_o.apply(attn_out)
         x = x + attn_out
+        return x
+
+    def _infer_ffn(self, weights, x, embed0):
         norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
         y = weights.ffn_0.apply(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
         y = torch.nn.functional.gelu(y, approximate="tanh")
         y = weights.ffn_2.apply(y)
         x = x + y * embed0[5].squeeze(0)
+        return x
+
+    def infer_block(self, weights, grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end):
+        if embed0.dim() == 3:
+            modulation = weights.modulation.tensor.unsqueeze(2)  # 1, 6, 1, dim
+            embed0 = embed0.unsqueeze(0)  #
+            embed0 = (modulation + embed0).chunk(6, dim=1)
+            embed0 = [ei.squeeze(1) for ei in embed0]
+        elif embed0.dim() == 2:
+            embed0 = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+        x = self._infer_self_attn(weights.compute_phases[0], grid_sizes, embed, x, embed0, seq_lens, freqs, context, block_idx, kv_start, kv_end)
+        x = self._infer_cross_attn(weights.compute_phases[1], x, context, block_idx)
+        x = self._infer_ffn(weights.compute_phases[2], x, embed0)
         return x
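
For readers without the repo, here is the _infer_ffn arithmetic as a self-contained script, with weights.ffn_0/weights.ffn_2 replaced by nn.Linear stand-ins and illustrative sizes (everything here except the formula itself is an assumption):

import torch

L, dim, ffn_dim = 4, 8, 32  # illustrative sizes
x = torch.randn(L, dim)
embed0 = list(torch.zeros(1, 6, dim).chunk(6, dim=1))  # six (1, 1, dim) signals, as in the 2-D branch
ffn_0 = torch.nn.Linear(dim, ffn_dim)  # stand-in for weights.ffn_0.apply
ffn_2 = torch.nn.Linear(ffn_dim, dim)  # stand-in for weights.ffn_2.apply

norm2_out = torch.nn.functional.layer_norm(x, (x.shape[1],), None, None, 1e-6)
y = ffn_0(norm2_out * (1 + embed0[4].squeeze(0)) + embed0[3].squeeze(0))
y = torch.nn.functional.gelu(y, approximate="tanh")
y = ffn_2(y)
x = x + y * embed0[5].squeeze(0)
print(x.shape)  # torch.Size([4, 8])

The new infer_block then simply chains the three phases over weights.compute_phases[0..2], so the per-block math is unchanged; only its packaging moves.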