xuwx1 / LightX2V · Commits

Commit 091a2a85, authored Sep 11, 2025 by sandy, committed by GitHub on Sep 11, 2025
Parent: accbf710

[Feat] For Sekotalk Add Torch Compile (#294)

Showing 10 changed files with 115 additions and 77 deletions (+115, -77)
Files changed:

  lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py    +71  -21
  lightx2v/models/networks/wan/infer/audio/post_infer.py            +5   -2
  lightx2v/models/networks/wan/infer/audio/pre_infer.py             +4   -3
  lightx2v/models/networks/wan/infer/audio/transformer_infer.py    +13  -32
  lightx2v/models/networks/wan/infer/module_io.py                   +7   -1
  lightx2v/models/networks/wan/infer/offload/transformer_infer.py   +1   -1
  lightx2v/models/networks/wan/infer/post_infer.py                  +5   -9
  lightx2v/models/networks/wan/infer/pre_infer.py                   +5   -4
  lightx2v/models/networks/wan/infer/transformer_infer.py           +1   -1
  lightx2v/models/networks/wan/infer/utils.py                       +3   -3
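Note: the through-line of this commit is making the Sekotalk/Wan audio path friendly to torch.compile, which the codebase gates behind a CHECK_ENABLE_GRAPH_MODE() helper (visible as unchanged context in the audio post_infer diff below). A minimal sketch of that gating pattern, assuming a hypothetical environment flag, since the helper's real implementation in lightx2v.utils.envs is not part of this diff:

import os

import torch


def check_enable_graph_mode() -> bool:
    # Hypothetical flag name, for illustration only; the real check lives in
    # lightx2v.utils.envs and is not shown in this diff.
    return os.getenv("ENABLE_GRAPH_MODE", "false").lower() == "true"


@torch.compile(disable=not check_enable_graph_mode())
def scale_and_shift(x: torch.Tensor, scale: torch.Tensor, shift: torch.Tensor) -> torch.Tensor:
    # With disable=True the decorator is a no-op and the function runs eagerly.
    return x * (1 + scale) + shift


out = scale_and_shift(torch.randn(2, 4), torch.zeros(2, 4), torch.zeros(2, 4))
print(out.shape)  # torch.Size([2, 4])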
lightx2v/models/input_encoders/hf/seko_audio/audio_adapter.py

@@ -19,30 +19,79 @@ def linear_interpolation(features, output_len: int):
     return output_features.transpose(1, 2)


-def get_q_lens_audio_range(
-    batchsize: int,
-    n_tokens_per_rank: int,
-    n_query_tokens: int,
-    n_tokens_per_frame: int,
-    sp_rank: int,
+@torch.compiler.disable
+def get_max_int(q_lens, k_lens):
+    max_seqlen_q = int(q_lens.max().item())
+    max_seqlen_k = int(k_lens.max().item())
+    return max_seqlen_q, max_seqlen_k
+
+
+def get_qk_lens_audio_range(
+    n_tokens_per_rank: torch.Tensor,
+    n_query_tokens: torch.Tensor,
+    n_tokens_per_frame: torch.Tensor,
+    sp_rank: torch.Tensor,
+    num_tokens_x4,
 ):
+    device = n_tokens_per_rank.device
+    dtype = torch.int32
     if n_query_tokens == 0:
-        q_lens = [1] * batchsize
-        return q_lens, 0, 1
+        q_lens = torch.ones(1, dtype=dtype, device=device)
+        t0 = torch.tensor(0, device=device)
+        t1 = torch.tensor(1, device=device)
+        k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device)
+        max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens)
+        return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1

     idx0 = n_tokens_per_rank * sp_rank
     first_length = n_tokens_per_frame - idx0 % n_tokens_per_frame
-    first_length = min(first_length, n_query_tokens)
-    n_frames = (n_query_tokens - first_length) // n_tokens_per_frame
+    first_length = torch.minimum(first_length, n_query_tokens)
+    n_frames = torch.div(n_query_tokens - first_length, n_tokens_per_frame, rounding_mode="floor")
     last_length = n_query_tokens - n_frames * n_tokens_per_frame - first_length
-    q_lens = []
-    if first_length > 0:
-        q_lens.append(first_length)
-    q_lens += [n_tokens_per_frame] * n_frames
-    if last_length > 0:
-        q_lens.append(last_length)
+    first_tensor = first_length.unsqueeze(0)  # [1]
+    frame_tensor = n_tokens_per_frame.repeat(n_frames)  # [n_frames]
+    last_tensor = last_length.unsqueeze(0)  # [1]
+    q_lens_all = torch.cat([first_tensor, frame_tensor, last_tensor])
+    q_lens = q_lens_all[q_lens_all > 0].to(dtype)
     t0 = idx0 // n_tokens_per_frame
-    t1 = t0 + len(q_lens)
-    return q_lens * batchsize, t0, t1
+    t1 = t0 + q_lens.numel()
+    k_lens = torch.full((t1 - t0,), num_tokens_x4, dtype=dtype, device=device)
+    assert q_lens.shape == k_lens.shape
+    max_seqlen_q, max_seqlen_k = get_max_int(q_lens, k_lens)
+    return q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1
+
+
+def calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens):
+    tail_length = n_tokens_per_rank * sp_size - n_tokens
+    n_unused_ranks = tail_length // n_tokens_per_rank
+    if sp_rank > sp_size - n_unused_ranks - 1:
+        n_query_tokens = 0
+    elif sp_rank == sp_size - n_unused_ranks - 1:
+        val = n_tokens_per_rank - (tail_length % n_tokens_per_rank)
+        n_query_tokens = val
+    else:
+        n_query_tokens = n_tokens_per_rank
+    if n_query_tokens > 0:
+        hidden_states_aligned = hidden_states[:, :n_query_tokens]
+        hidden_states_tail = hidden_states[:, n_query_tokens:]
+    else:
+        # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
+        hidden_states_aligned = hidden_states[:, :1]
+        hidden_states_tail = hidden_states[:, 1:]
+    return n_query_tokens, hidden_states_aligned, hidden_states_tail


 class PerceiverAttentionCA(nn.Module):

@@ -73,7 +122,7 @@ class PerceiverAttentionCA(nn.Module):
         shift_scale_gate[:, 2] = 1
         self.register_buffer("shift_scale_gate", shift_scale_gate, persistent=False)

-    def forward(self, x, latents, t_emb, q_lens, k_lens):
+    def forward(self, x, latents, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k):
         """x shape (batchsize, latent_frame, audio_tokens_per_latent,
         model_dim) latents (batchsize, length, model_dim)"""
         batchsize = len(x)

@@ -90,14 +139,15 @@ class PerceiverAttentionCA(nn.Module):
         q = rearrange(q, "B L (H C) -> (B L) H C", H=self.heads)
         k = rearrange(k, "B T L (H C) -> (B T L) H C", H=self.heads)
         v = rearrange(v, "B T L (H C) -> (B T L) H C", H=self.heads)
         out = flash_attn.flash_attn_varlen_func(
             q=q,
             k=k,
             v=v,
             cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
             cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32).to(q.device, non_blocking=True),
-            max_seqlen_q=q_lens.max().item(),
-            max_seqlen_k=k_lens.max().item(),
+            max_seqlen_q=max_seqlen_q,
+            max_seqlen_k=max_seqlen_k,
             dropout_p=0.0,
             softmax_scale=None,
             causal=False,
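Note: two compile-oriented patterns show up in this hunk: sequence-length bookkeeping is kept in tensor ops instead of Python lists, and the unavoidable .item() calls (which force a device sync) are fenced off with @torch.compiler.disable so they do not land inside a compiled region. A minimal standalone sketch of both, with made-up lengths:

import torch


@torch.compiler.disable
def get_max_int(q_lens: torch.Tensor, k_lens: torch.Tensor):
    # .item() synchronizes with the device; keeping it behind compiler.disable
    # keeps it out of any compiled graph.
    return int(q_lens.max().item()), int(k_lens.max().item())


def build_q_lens(first_length: torch.Tensor, n_tokens_per_frame: torch.Tensor, n_frames: int, last_length: torch.Tensor) -> torch.Tensor:
    # Tensor-only construction (no list appends), mirroring get_qk_lens_audio_range;
    # zero-length entries are dropped with a boolean mask.
    q_lens_all = torch.cat([
        first_length.unsqueeze(0),
        n_tokens_per_frame.repeat(n_frames),
        last_length.unsqueeze(0),
    ])
    return q_lens_all[q_lens_all > 0].to(torch.int32)


q_lens = build_q_lens(torch.tensor(120), torch.tensor(3600), 4, torch.tensor(0))
k_lens = torch.full_like(q_lens, 128)
print(q_lens.tolist(), get_max_int(q_lens, k_lens))  # [120, 3600, 3600, 3600, 3600] (3600, 128)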
lightx2v/models/networks/wan/infer/audio/post_infer.py

@@ -11,8 +11,11 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE())
     def infer(self, x, pre_infer_out):
         x = x[: pre_infer_out.seq_lens[0]]
-        pre_infer_out.grid_sizes[:, 0] -= 1
-        x = self.unpatchify(x, pre_infer_out.grid_sizes)
+        t, h, w = pre_infer_out.grid_sizes.tuple
+        grid_sizes = (t - 1, h, w)
+        x = self.unpatchify(x, grid_sizes)
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()
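Note: since infer here is wrapped in @torch.compile(disable=not CHECK_ENABLE_GRAPH_MODE()), the change swaps an in-place edit of the grid tensor for arithmetic on the plain Python tuple carried by the new GridOutput. A small sketch of that shape handling, with hypothetical sizes:

import torch


def trim_ref_frame(grid_tuple):
    # Plain ints (GridOutput.tuple): no device sync, and torch.compile can
    # specialize on them as static shapes.
    t, h, w = grid_tuple
    return (t - 1, h, w)  # replaces the old in-place grid_sizes[:, 0] -= 1


t, h, w = trim_ref_frame((21, 4, 4))   # made-up latent grid
x = torch.randn(21 * 4 * 4, 8)         # made-up token sequence (t*h*w, dim)
x = x[: t * h * w]                     # analogous to slicing by seq_lens above
print(x.shape)                         # torch.Size([320, 8])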
lightx2v/models/networks/wan/infer/audio/pre_infer.py

@@ -3,7 +3,7 @@ import torch
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
-from ..module_io import WanPreInferModuleOutput
+from ..module_io import GridOutput, WanPreInferModuleOutput
 from ..utils import rope_params, sinusoidal_embedding_1d

@@ -61,9 +61,9 @@ class WanAudioPreInfer(WanPreInfer):
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
+        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()

@@ -114,6 +114,7 @@ class WanAudioPreInfer(WanPreInfer):
         del context_clip
         torch.cuda.empty_cache()
+        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
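Note: the pre-infer change builds its small index tensors with an explicit device=/dtype= instead of allocating on CPU and calling .cuda(), then packs the grid into the GridOutput container defined in module_io.py below. A tiny sketch of the device-aware construction, runnable on CPU:

import torch


def make_seq_lens(x: torch.Tensor) -> torch.Tensor:
    # One allocation on x's device; no follow-up .cuda() copy and no implicit
    # default-device assumption.
    return torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)


x = torch.zeros(1, 77, 16)   # made-up (batch, tokens, dim) patch sequence
print(make_seq_lens(x))      # tensor([77], dtype=torch.int32)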
lightx2v/models/networks/wan/infer/audio/transformer_infer.py

 import torch
 import torch.distributed as dist
-from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import get_q_lens_audio_range
+from lightx2v.models.input_encoders.hf.seko_audio.audio_adapter import calculate_n_query_tokens, get_qk_lens_audio_range
 from lightx2v.models.networks.wan.infer.offload.transformer_infer import WanOffloadTransformerInfer

@@ -18,11 +18,9 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
-        audio_grid_sizes = [row.clone() for row in pre_infer_out.grid_sizes]
-        audio_grid_sizes[0][0] -= 1
         x = self.modify_hidden_states(
             hidden_states=x.to(self.infer_dtype),
-            grid_sizes=audio_grid_sizes,
+            grid_sizes=pre_infer_out.grid_sizes.tensor,
             ca_block=self.audio_adapter.ca[self.block_idx],
             audio_encoder_output=pre_infer_out.adapter_output["audio_encoder_output"],
             t_emb=self.scheduler.audio_adapter_t_emb,

@@ -41,11 +39,14 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
         """
         if len(hidden_states.shape) == 2:  # expand the batchsize dim
             hidden_states = hidden_states.unsqueeze(0)  # bs = 1
-        t, h, w = grid_sizes[0].tolist()
-        n_tokens = t * h * w
+        total_tokens = grid_sizes[0].prod()
+        pre_frame_tokens = grid_sizes[0][1:].prod()
+        n_tokens = total_tokens - pre_frame_tokens  # drop the ref image's token count
         ori_dtype = hidden_states.dtype
         device = hidden_states.device
-        bs, n_tokens_per_rank = hidden_states.shape[:2]
+        n_tokens_per_rank = torch.tensor(hidden_states.size(1), dtype=torch.int32, device=device)
         if seq_p_group is not None:
             sp_size = dist.get_world_size(seq_p_group)

@@ -54,35 +55,15 @@ class WanAudioTransformerInfer(WanOffloadTransformerInfer):
             sp_size = 1
             sp_rank = 0
-        tail_length = n_tokens_per_rank * sp_size - n_tokens
-        n_unused_ranks = tail_length // n_tokens_per_rank
-        if sp_rank > sp_size - n_unused_ranks - 1:
-            n_query_tokens = 0
-        elif sp_rank == sp_size - n_unused_ranks - 1:
-            n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
-        else:
-            n_query_tokens = n_tokens_per_rank
-        if n_query_tokens > 0:
-            hidden_states_aligned = hidden_states[:, :n_query_tokens]
-            hidden_states_tail = hidden_states[:, n_query_tokens:]
-        else:
-            # for ranks that should be excluded from cross-attn, fake cross-attn will be applied so that FSDP works.
-            hidden_states_aligned = hidden_states[:, :1]
-            hidden_states_tail = hidden_states[:, 1:]
-        q_lens, t0, t1 = get_q_lens_audio_range(batchsize=bs, n_tokens_per_rank=n_tokens_per_rank, n_query_tokens=n_query_tokens, n_tokens_per_frame=h * w, sp_rank=sp_rank)
-        q_lens = torch.tensor(q_lens, device=device, dtype=torch.int32)
-        """
-        processing audio features in sp_state can be moved outside.
-        """
-        audio_encoder_output = audio_encoder_output[:, t0:t1]
-        k_lens = torch.tensor([self.num_tokens_x4] * (t1 - t0) * bs, device=device, dtype=torch.int32)
-        assert q_lens.shape == k_lens.shape
+        n_query_tokens, hidden_states_aligned, hidden_states_tail = calculate_n_query_tokens(hidden_states, sp_rank, sp_size, n_tokens_per_rank, n_tokens)
+        q_lens, k_lens, max_seqlen_q, max_seqlen_k, t0, t1 = get_qk_lens_audio_range(
+            n_tokens_per_rank=n_tokens_per_rank,
+            n_query_tokens=n_query_tokens,
+            n_tokens_per_frame=pre_frame_tokens,
+            sp_rank=sp_rank,
+            num_tokens_x4=self.num_tokens_x4
+        )
         # ca_block: the CrossAttention function
         if self.audio_adapter.cpu_offload:
             ca_block.to("cuda")
-        residual = ca_block(audio_encoder_output, hidden_states_aligned, t_emb, q_lens, k_lens) * weight
+        residual = ca_block(audio_encoder_output[:, t0:t1], hidden_states_aligned, t_emb, q_lens, k_lens, max_seqlen_q, max_seqlen_k) * weight
         if self.audio_adapter.cpu_offload:
             ca_block.to("cpu")
         residual = residual.to(ori_dtype)  # after cross-attention, the audio is injected as a residual
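Note: the rank bookkeeping that used to live inline here now sits in calculate_n_query_tokens / get_qk_lens_audio_range. A plain-Python sketch of the same arithmetic, with made-up sizes, to show how the tail padding decides which sequence-parallel ranks do real cross-attention:

# Made-up sequence-parallel setup: 4 ranks, 1000 tokens per rank, but only 3500
# "real" video tokens once the reference-image frame is excluded.
sp_size = 4
n_tokens_per_rank = 1000
n_tokens = 3500

tail_length = n_tokens_per_rank * sp_size - n_tokens   # 500 padding tokens overall
n_unused_ranks = tail_length // n_tokens_per_rank      # 0 fully padded ranks here

for sp_rank in range(sp_size):
    if sp_rank > sp_size - n_unused_ranks - 1:
        n_query_tokens = 0                              # pure padding: fake cross-attn only
    elif sp_rank == sp_size - n_unused_ranks - 1:
        n_query_tokens = n_tokens_per_rank - tail_length % n_tokens_per_rank
    else:
        n_query_tokens = n_tokens_per_rank
    print(sp_rank, n_query_tokens)                      # ranks 0..2 -> 1000, rank 3 -> 500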
lightx2v/models/networks/wan/infer/module_io.py

@@ -4,10 +4,16 @@ from typing import Any, Dict
 import torch

+@dataclass
+class GridOutput:
+    tensor: torch.Tensor
+    tuple: tuple
+
+
 @dataclass
 class WanPreInferModuleOutput:
     embed: torch.Tensor
-    grid_sizes: torch.Tensor
+    grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
     seq_lens: torch.Tensor
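Note: GridOutput carries the grid both ways: the tensor feeds ops that genuinely want a tensor (e.g. the audio adapter path above), while the int tuple feeds reshapes and loop bounds inside code that may be torch.compile'd, so no .item() call happens there. A minimal sketch of how the pre-infer code constructs it (values are made up):

from dataclasses import dataclass

import torch


@dataclass
class GridOutput:
    tensor: torch.Tensor
    tuple: tuple


g = torch.tensor([[21, 45, 80]], dtype=torch.int32)   # made-up (f, h, w) grid
grid = GridOutput(tensor=g, tuple=(g[0][0].item(), g[0][1].item(), g[0][2].item()))
print(grid.tuple)          # (21, 45, 80)
print(grid.tensor.shape)   # torch.Size([1, 3])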
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

@@ -188,7 +188,7 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
         self.phase_params["y_out"] = self.infer_self_attn(
             cur_phase,
-            pre_infer_out.grid_sizes,
+            pre_infer_out.grid_sizes.tuple,
             x,
             pre_infer_out.seq_lens,
             pre_infer_out.freqs,
lightx2v/models/networks/wan/infer/post_infer.py

@@ -15,7 +15,7 @@ class WanPostInfer:
         self.scheduler = scheduler

     def infer(self, x, pre_infer_out):
-        x = self.unpatchify(x, pre_infer_out.grid_sizes)
+        x = self.unpatchify(x, pre_infer_out.grid_sizes.tuple)
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()

@@ -23,12 +23,8 @@ class WanPostInfer:
         return [u.float() for u in x]

     def unpatchify(self, x, grid_sizes):
-        x = x.unsqueeze(0)
         c = self.out_dim
-        out = []
-        for u, v in zip(x, grid_sizes.tolist()):
-            u = u[: math.prod(v)].view(*v, *self.patch_size, c)
-            u = torch.einsum("fhwpqrc->cfphqwr", u)
-            u = u.reshape(c, *[i * j for i, j in zip(v, self.patch_size)])
-            out.append(u)
-        return out
+        x = x[: math.prod(grid_sizes)].view(*grid_sizes, *self.patch_size, c)
+        x = torch.einsum("fhwpqrc->cfphqwr", x)
+        x = x.reshape(c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
+        return [x]
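Note: the rewritten unpatchify drops the per-sample loop and works on a single sequence with a plain (f, h, w) tuple. A standalone shape check of the same einsum fold, with made-up grid, patch, and channel sizes:

import math

import torch

f, h, w = 2, 3, 4   # made-up latent grid
p, q, r = 1, 2, 2   # made-up patch size per axis
c = 16              # made-up out_dim

x = torch.randn(f * h * w, p * q * r * c)                 # token sequence
x = x[: math.prod((f, h, w))].view(f, h, w, p, q, r, c)   # fold tokens back into patches
x = torch.einsum("fhwpqrc->cfphqwr", x)                   # interleave patch dims with grid dims
out = x.reshape(c, f * p, h * q, w * r)                   # (channels, frames, height, width)
print(out.shape)                                          # torch.Size([16, 2, 6, 8])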
lightx2v/models/networks/wan/infer/pre_infer.py

@@ -2,7 +2,7 @@ import torch
 from lightx2v.utils.envs import *
-from .module_io import WanPreInferModuleOutput
+from .module_io import GridOutput, WanPreInferModuleOutput
 from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d

@@ -61,13 +61,13 @@ class WanPreInfer:
         # embeddings
         x = weights.patch_embedding.apply(x.unsqueeze(0))
-        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long).unsqueeze(0)
+        grid_sizes = torch.tensor(x.shape[2:], dtype=torch.int32, device=x.device).unsqueeze(0)
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.long).cuda().unsqueeze(0)
+        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
-            s = torch.tensor([self.cfg_scale], dtype=torch.float32).to(x.device)
+            s = torch.tensor([self.cfg_scale], dtype=torch.float32, device=x.device)
             cfg_embed = guidance_scale_embedding(s, embedding_dim=256, cfg_range=(1.0, 6.0), target_range=1000.0, dtype=torch.float32).type_as(x)
             cfg_embed = weights.cfg_cond_proj_1.apply(cfg_embed)
             cfg_embed = torch.nn.functional.silu(cfg_embed)

@@ -117,6 +117,7 @@ class WanPreInfer:
         del context_clip
         torch.cuda.empty_cache()
+        grid_sizes = GridOutput(tensor=grid_sizes, tuple=(grid_sizes[0][0].item(), grid_sizes[0][1].item(), grid_sizes[0][2].item()))
         return WanPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
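Note: same pattern as the audio pre-infer above: the cfg-scale scalar is now created directly on the latent's device instead of being allocated on CPU and moved with .to(). A one-line sketch (CPU stands in for the real device, and the value is made up):

import torch

x = torch.zeros(4, 8)        # stand-in for the latent tensor
cfg_scale = 5.0              # made-up value of self.cfg_scale
s = torch.tensor([cfg_scale], dtype=torch.float32, device=x.device)
print(s, s.device)           # tensor([5.]) cpu  (cuda in the real pipeline)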
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -96,7 +96,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         y_out = self.infer_self_attn(
             block.compute_phases[0],
-            pre_infer_out.grid_sizes,
+            pre_infer_out.grid_sizes.tuple,
             x,
             pre_infer_out.seq_lens,
             pre_infer_out.freqs,
lightx2v/models/networks/wan/infer/utils.py

@@ -6,7 +6,7 @@ from lightx2v.utils.envs import *
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [

@@ -24,7 +24,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
     world_size = dist.get_world_size(seq_p_group)
     cur_rank = dist.get_rank(seq_p_group)
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [

@@ -43,7 +43,7 @@ def compute_freqs_dist(s, c, grid_sizes, freqs, seq_p_group):
 def compute_freqs_causvid(c, grid_sizes, freqs, start_frame=0):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
-    f, h, w = grid_sizes[0]
+    f, h, w = grid_sizes
     seq_len = f * h * w
     freqs_i = torch.cat(
         [
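Note: the frequency helpers now receive the grid as a plain tuple, so "f, h, w = grid_sizes" unpacks ints directly. A shape-only sketch of the split they perform on the rotary table (the table here is made up; the real one comes from rope_params):

import torch

c = 64                          # made-up rotary dim per head
freqs = torch.randn(1024, c)    # made-up precomputed table (rope_params output in the real code)
f_freqs, h_freqs, w_freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)

f, h, w = (21, 45, 80)          # grid as a plain tuple -- made up
seq_len = f * h * w
print(f_freqs.shape, h_freqs.shape, w_freqs.shape, seq_len)
# torch.Size([1024, 22]) torch.Size([1024, 21]) torch.Size([1024, 21]) 75600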