xuwx1 / LightX2V · Commit 47b3ce2f (unverified)

Update wan infer rope (#518)

Authored Nov 27, 2025 by Yang Yong (雍洋); committed via GitHub on Nov 27, 2025. Parent: 5f277e80.

In summary, the commit moves RoPE frequency preparation out of the per-step pre-infer path and into the scheduler: WanScheduler now precomputes a complex cos_sin cache once per generation (prepare_cos_sin), and WanTransformerInfer applies it through the new apply_wan_rope_with_torch / apply_wan_rope_with_flashinfer / apply_wan_rope_with_chunk helpers instead of recomputing freqs inside self-attention.

Showing 10 changed files with 179 additions and 62 deletions (+179 −62):
  lightx2v/models/networks/wan/infer/audio/post_infer.py                +0   −2
  lightx2v/models/networks/wan/infer/audio/pre_infer.py                 +3   −15
  lightx2v/models/networks/wan/infer/module_io.py                       +0   −2
  lightx2v/models/networks/wan/infer/pre_infer.py                       +2   −13
  lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py          +15  −2
  lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py  +3   −0
  lightx2v/models/networks/wan/infer/transformer_infer.py               +27  −28
  lightx2v/models/networks/wan/infer/utils.py                           +79  −0
  lightx2v/models/schedulers/wan/audio/scheduler.py                     +5   −0
  lightx2v/models/schedulers/wan/scheduler.py                           +45  −0
lightx2v/models/networks/wan/infer/audio/post_infer.py (+0 −2)

```diff
@@ -10,8 +10,6 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.no_grad()
     def infer(self, x, pre_infer_out):
-        x = x[: pre_infer_out.seq_lens[0]]
-
         t, h, w = pre_infer_out.grid_sizes.tuple
         grid_sizes = (t - 1, h, w)
```
lightx2v/models/networks/wan/infer/audio/pre_infer.py (+3 −15)

```diff
@@ -4,27 +4,17 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import GridOutput, WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d
+from ..utils import sinusoidal_embedding_1d


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
-        d = config["dim"] // config["num_heads"]
         self.config = config
         self.task = config["task"]
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -65,14 +55,14 @@ class WanAudioPreInfer(WanPreInfer):
         x = weights.patch_embedding.apply(x.unsqueeze(0))
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
         x = torch.cat([x, y], dim=1).squeeze(0)
         ####for r2v # zero temporl component corresponding to ref embeddings
-        self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
+        # self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
         grid_sizes_t += 1
         person_mask_latens = inputs["person_mask_latens"]
@@ -126,8 +116,6 @@ class WanAudioPreInfer(WanPreInfer):
             grid_sizes=grid_sizes,
             x=x,
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
```
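For context on the patchify step kept above (the `weights.patch_embedding.apply` / `flatten` / `transpose` sequence), here is a hypothetical standalone sketch with toy sizes; the `Conv3d` stand-in and all dimensions are illustrative assumptions, not the model's real patch embedding:

```python
import torch

# Hypothetical toy patchify: latents (B, C, F, H, W) -> tokens (B, seq_len, dim).
# The Conv3d stand-in and sizes are assumptions, not LightX2V's actual weights.
patch = torch.nn.Conv3d(16, 32, kernel_size=(1, 2, 2), stride=(1, 2, 2))
lat = torch.randn(1, 16, 4, 8, 8)                # (B, C, F, H, W)
x = patch(lat)                                   # (1, 32, 4, 4, 4)
grid_t, grid_h, grid_w = x.shape[2:]             # grid sizes after patchify
x = x.flatten(2).transpose(1, 2).contiguous()    # (1, 4*4*4, 32) token sequence
print((grid_t, grid_h, grid_w), x.shape)         # (4, 4, 4) torch.Size([1, 64, 32])
```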
lightx2v/models/networks/wan/infer/module_io.py (+0 −2)

```diff
@@ -16,8 +16,6 @@ class WanPreInferModuleOutput:
     grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
-    seq_lens: torch.Tensor
-    freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
     conditional_dict: Dict[str, Any] = field(default_factory=dict)
```
lightx2v/models/networks/wan/infer/pre_infer.py (+2 −13)

```diff
@@ -3,26 +3,17 @@ import torch
 from lightx2v.utils.envs import *
 from .module_io import GridOutput, WanPreInferModuleOutput
-from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
+from .utils import guidance_scale_embedding, sinusoidal_embedding_1d


 class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        d = config["dim"] // config["num_heads"]
         self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.device = torch.device(self.config.get("run_device", "cuda"))
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
@@ -71,7 +62,7 @@ class WanPreInfer:
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
@@ -130,8 +121,6 @@ class WanPreInfer:
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"motion_vec": motion_vec},
         )
```
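The hunk above keeps the `sinusoidal_embedding_1d(self.freq_dim, t.flatten())` timestep embedding untouched. As a reference point, a minimal sketch of one common half-sin/half-cos layout for such an embedding (an assumption for illustration; the real helper lives in `.utils` and may differ in layout):

```python
import torch

# Minimal sketch of a 1D sinusoidal timestep embedding (assumed layout;
# the actual sinusoidal_embedding_1d in lightx2v may differ).
def sinusoidal_embedding_1d(dim, t):
    half = dim // 2
    freqs = torch.outer(t.float(), 1.0 / 10000 ** (torch.arange(half) / half))
    return torch.cat([torch.sin(freqs), torch.cos(freqs)], dim=1)

print(sinusoidal_embedding_1d(8, torch.tensor([0, 500, 999])).shape)  # torch.Size([3, 8])
```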
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py (+15 −2)

```diff
+from dataclasses import dataclass
+
 import torch

-from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
+from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
@@ -24,6 +26,17 @@ def rope_params(max_seq_len, dim, theta=10000):
     return freqs


+@dataclass
+class WanSFPreInferModuleOutput:
+    embed: torch.Tensor
+    grid_sizes: GridOutput
+    x: torch.Tensor
+    embed0: torch.Tensor
+    seq_lens: torch.Tensor
+    freqs: torch.Tensor
+    context: torch.Tensor
+
+
 class WanSFPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
@@ -87,7 +100,7 @@ class WanSFPreInfer(WanPreInfer):
         grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

-        return WanPreInferModuleOutput(
+        return WanSFPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
```
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py (+3 −0)

```diff
@@ -50,6 +50,9 @@ class WanSFTransformerInfer(WanTransformerInfer):
         self.infer_func = self.infer_with_kvcache

+    def get_scheduler_values(self):
+        pass
+
     def _initialize_kv_cache(self, dtype, device):
         """
         Initialize a Per-GPU KV cache for the Wan model.
```
lightx2v/models/networks/wan/infer/transformer_infer.py (+27 −28)

```diff
@@ -5,7 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
-from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_dist
+from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


 class WanTransformerInfer(BaseTransformerInfer):
@@ -20,11 +20,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
-        if config.get("rotary_chunk", False):
-            chunk_size = config.get("rotary_chunk_size", 100)
-            self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
-        else:
-            self.apply_rotary_emb_func = apply_rotary_emb
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_flashinfer
+        else:
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_torch)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_torch
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
@@ -35,21 +40,20 @@ class WanTransformerInfer(BaseTransformerInfer):
             self.seq_p_group = None

         self.infer_func = self.infer_without_offload
+        self.cos_sin = None

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if self.config["seq_parallel"]:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
+    def get_scheduler_values(self):
+        self.cos_sin = self.scheduler.cos_sin

     @torch.no_grad()
     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)
@@ -97,10 +101,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             )

             y_out = self.infer_self_attn(
                 block.compute_phases[0],
-                pre_infer_out.grid_sizes.tuple,
                 x,
-                pre_infer_out.seq_lens,
-                pre_infer_out.freqs,
                 shift_msa,
                 scale_msa,
             )
@@ -129,7 +130,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, phase, x, shift_msa, scale_msa):
+        cos_sin = self.cos_sin
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
@@ -153,16 +155,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = phase.self_attn_v.apply(norm1_out).view(s, n, d)

-        freqs_i = self.compute_freqs(q, grid_sizes, freqs)
-        q = self.apply_rotary_emb_func(q, freqs_i)
-        k = self.apply_rotary_emb_func(k, freqs_i)
+        q, k = self.apply_rope_func(q, k, cos_sin)

-        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
+        img_qkv_len = q.shape[0]
+        cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del freqs_i, norm1_out, norm1_weight, norm1_bias
+            del norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
@@ -170,8 +169,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                img_qkv_len=q.shape[0],
-                cu_seqlens_qkv=cu_seqlens_q,
+                img_qkv_len=img_qkv_len,
+                cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],
@@ -181,10 +180,10 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=q.size(0),
-                max_seqlen_kv=k.size(0),
+                cu_seqlens_q=cu_seqlens_qkv,
+                cu_seqlens_kv=cu_seqlens_qkv,
+                max_seqlen_q=img_qkv_len,
+                max_seqlen_kv=img_qkv_len,
                 model_cls=self.config["model_cls"],
             )
```
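The constructor change above replaces the old `rotary_chunk` switch with a two-level choice: backend (`rope_type`: flashinfer vs. torch) times chunking (`rope_chunk`), composed with `functools.partial`. A minimal sketch of the same dispatch pattern, using stand-in functions (hypothetical names, not the LightX2V API):

```python
from functools import partial

# Stand-ins for apply_wan_rope_with_torch / apply_wan_rope_with_chunk (illustrative only).
def rope_plain(q, k, cos_sin):
    return q, k

def rope_chunked(q, k, cos_sin, chunk_size, rope_func):
    return rope_func(q, k, cos_sin)  # the real version loops over [start:end] chunks

def pick_rope_func(config, plain, chunked):
    # Mirrors the diff: chunking wraps the chosen backend via functools.partial,
    # so callers always invoke apply_rope_func(q, k, cos_sin) with one signature.
    if config.get("rope_chunk", False):
        return partial(chunked, chunk_size=config.get("rope_chunk_size", 100), rope_func=plain)
    return plain

apply_rope_func = pick_rope_func({"rope_chunk": True}, rope_plain, rope_chunked)
print(apply_rope_func("q", "k", "cos_sin"))  # ('q', 'k') -- same call either way
```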
lightx2v/models/networks/wan/infer/utils.py (+79 −0)

```diff
 import torch
 import torch.distributed as dist
+from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

 from lightx2v.utils.envs import *


+def apply_wan_rope_with_torch(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    n = xq.size(1)
+    seq_len = cos_sin_cache.size(0)
+    xq = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    xk = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    # Apply rotary embedding
+    xq = torch.view_as_real(xq * cos_sin_cache).flatten(2)
+    xk = torch.view_as_real(xk * cos_sin_cache).flatten(2)
+    xq = torch.cat([xq, xq[seq_len:]])
+    xk = torch.cat([xk, xk[seq_len:]])
+    return xq.to(GET_DTYPE()), xk.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_chunk(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    chunk_size: int,
+    rope_func,
+):
+    seq_len = cos_sin_cache.size(0)
+    xq_output_chunks = []
+    xk_output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        xq_chunk = xq[start:end]
+        xk_chunk = xk[start:end]
+        cos_sin_chunk = cos_sin_cache[start:end]
+        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        xq_output_chunks.append(xq_chunk)
+        xk_output_chunks.append(xk_chunk)
+        torch.cuda.empty_cache()
+    x_q = torch.cat(xq_output_chunks, dim=0)
+    del xq_output_chunks
+    torch.cuda.empty_cache()
+    x_k = torch.cat(xk_output_chunks, dim=0)
+    del xk_output_chunks
+    torch.cuda.empty_cache()
+    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_flashinfer(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    L, H, D = xq.shape
+    query = xq.reshape(L, H * D).contiguous()
+    key = xk.reshape(L, H * D).contiguous()
+    positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True)
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=D,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=False,
+    )
+    xq_out = query.view(L, H, D)
+    xk_out = key.view(L, H, D)
+    return xq_out, xk_out
+
+
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes
```
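The `apply_wan_rope_with_torch` path treats each (cos, sin) pair as one complex number and rotates q/k by elementwise complex multiplication. A hypothetical standalone demo of that rotation with toy shapes (no lightx2v or flashinfer imports; `GET_DTYPE` omitted):

```python
import torch

# Toy RoPE rotation via complex multiplication (illustrative shapes only).
seq_len, n_heads, head_dim = 8, 2, 16
angles = torch.outer(torch.arange(seq_len).float(), 1.0 / 10000 ** (torch.arange(0, head_dim, 2) / head_dim))
cos_sin = torch.polar(torch.ones_like(angles), angles)   # complex, (seq_len, head_dim // 2)

xq = torch.randn(seq_len, n_heads, head_dim)
xq_c = torch.view_as_complex(xq.float().reshape(seq_len, n_heads, -1, 2))
xq_rot = torch.view_as_real(xq_c * cos_sin.unsqueeze(1)).flatten(2)  # unit-modulus product = pure rotation
print(xq_rot.shape)  # torch.Size([8, 2, 16]); each 2-dim pair keeps its norm
```

In the scheduler's torch branch the cache is already stored as (seq_len, 1, head_dim // 2), so the real helper needs no `unsqueeze`.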
lightx2v/models/schedulers/wan/audio/scheduler.py (+5 −0)

```diff
@@ -12,6 +12,8 @@ from lightx2v.utils.utils import masks_like
 class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
+        d = config["dim"] // config["num_heads"]
+        self.rope_t_dim = d // 2 - 2 * (d // 6)
         if self.config["parallel"]:
             self.sp_size = self.config["parallel"].get("seq_p_size", 1)
@@ -83,6 +85,9 @@ class EulerScheduler(WanScheduler):
         self.timesteps = self.sigmas * self.num_train_timesteps

+        self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
     def step_post(self):
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)
```
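These two added lines make the audio (r2v) scheduler own what `WanAudioPreInfer` used to do in-place each step: rows of the rope table at or beyond the video frame count get a zero temporal factor, so reference-frame tokens carry no temporal rotation component (cf. the "zero temporl component corresponding to ref embeddings" comment in audio/pre_infer.py). A toy sketch of the effect, with hypothetical sizes:

```python
import torch

# Toy rope table: zeroing temporal columns for rows past the video frames
# kills those channels after the complex multiply (illustrative sizes only).
table = torch.polar(torch.ones(6, 4), torch.rand(6, 4))  # complex table (positions x dims)
rope_t_dim, video_f = 2, 4                               # temporal cols; rows >= video_f are the ref frame
table[video_f:, :rope_t_dim] = 0
x = torch.randn(6, 4, dtype=torch.complex64)
print((x * table)[video_f:, :rope_t_dim].abs().max())    # tensor(0.): temporal component zeroed
```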
lightx2v/models/schedulers/wan/scheduler.py (+45 −0)

```diff
@@ -14,6 +14,8 @@ class WanScheduler(BaseScheduler):
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
+        self.run_device = self.config.get("run_device", "cuda")
+        self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
@@ -21,6 +23,24 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.sample_guide_scale = self.config["sample_guide_scale"]
         self.caching_records_2 = [True] * self.config["infer_steps"]

+        self.head_size = self.config["dim"] // self.config["num_heads"]
+        self.freqs = torch.cat(
+            [
+                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+            ],
+            dim=1,
+        ).to(torch.device(self.run_device))
+
+    def rope_params(self, max_seq_len, dim, theta=10000):
+        assert dim % 2 == 0
+        freqs = torch.outer(
+            torch.arange(max_seq_len),
+            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
+        )
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs
+
     def prepare(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
@@ -47,6 +67,31 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
+    def prepare_cos_sin(self, grid_sizes):
+        c = self.head_size // 2
+        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes
+        seq_len = f * h * w
+        cos_sin = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        )
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            cos_sin = cos_sin.reshape(seq_len, -1)
+            # Extract cos and sin parts separately and concatenate
+            cos_half = cos_sin.real.contiguous()
+            sin_half = cos_sin.imag.contiguous()
+            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
+        else:
+            cos_sin = cos_sin.reshape(seq_len, 1, -1)
+        return cos_sin
+
     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
         self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
```
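To make the new `prepare_cos_sin` concrete, here is a hypothetical standalone run with a toy head size and grid (all sizes illustrative): the three per-axis frequency tables are broadcast over the (f, h, w) grid, concatenated per position, then split into cos/sin halves for the flashinfer layout:

```python
import torch

# Toy prepare_cos_sin: per-position cache of shape (f*h*w, head_size) for flashinfer.
head_size = 12
c = head_size // 2

def rope_params(max_seq_len, dim, theta=10000):
    freqs = torch.outer(torch.arange(max_seq_len).float(), 1.0 / theta ** (torch.arange(0, dim, 2).float() / dim))
    return torch.polar(torch.ones_like(freqs), freqs)

freqs = torch.cat(
    [
        rope_params(1024, head_size - 4 * (head_size // 6)),  # temporal axis
        rope_params(1024, 2 * (head_size // 6)),              # height axis
        rope_params(1024, 2 * (head_size // 6)),              # width axis
    ],
    dim=1,
)  # complex, (1024, head_size // 2)

f, h, w = 3, 4, 5  # latent grid after patchify
parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
cos_sin = torch.cat(
    [
        parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ],
    dim=-1,
).reshape(f * h * w, -1)
cache = torch.cat([cos_sin.real.contiguous(), cos_sin.imag.contiguous()], dim=-1)
print(cache.shape)  # torch.Size([60, 12]): cos half then sin half per token
```

Precomputing this cache once per generation (rather than rebuilding freqs in every attention block, every step) is the main point of the commit; the transformer just reads `self.scheduler.cos_sin` via `get_scheduler_values`.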