Commit 47b3ce2f in xuwx1/LightX2V

Unverified commit, authored Nov 27, 2025 by Yang Yong (雍洋), committed via GitHub on Nov 27, 2025.
Parent: 5f277e80

Update wan infer rope (#518)

Showing 10 changed files with 179 additions and 62 deletions (+179 -62).
Changed files:

  lightx2v/models/networks/wan/infer/audio/post_infer.py                +0   -2
  lightx2v/models/networks/wan/infer/audio/pre_infer.py                 +3   -15
  lightx2v/models/networks/wan/infer/module_io.py                       +0   -2
  lightx2v/models/networks/wan/infer/pre_infer.py                       +2   -13
  lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py          +15  -2
  lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py  +3   -0
  lightx2v/models/networks/wan/infer/transformer_infer.py               +27  -28
  lightx2v/models/networks/wan/infer/utils.py                           +79  -0
  lightx2v/models/schedulers/wan/audio/scheduler.py                     +5   -0
  lightx2v/models/schedulers/wan/scheduler.py                           +45  -0
lightx2v/models/networks/wan/infer/audio/post_infer.py (+0, -2)

@@ -10,8 +10,6 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.no_grad()
     def infer(self, x, pre_infer_out):
-        x = x[: pre_infer_out.seq_lens[0]]
         t, h, w = pre_infer_out.grid_sizes.tuple
         grid_sizes = (t - 1, h, w)
         ...
lightx2v/models/networks/wan/infer/audio/pre_infer.py (+3, -15)

@@ -4,27 +4,17 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import GridOutput, WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d
+from ..utils import sinusoidal_embedding_1d


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
-        d = config["dim"] // config["num_heads"]
         self.config = config
         self.task = config["task"]
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
         ...

@@ -65,14 +55,14 @@ class WanAudioPreInfer(WanPreInfer):
         x = weights.patch_embedding.apply(x.unsqueeze(0))
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
         x = torch.cat([x, y], dim=1).squeeze(0)
         ####for r2v # zero temporl component corresponding to ref embeddings
-        self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
+        # self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
         grid_sizes_t += 1
         person_mask_latens = inputs["person_mask_latens"]
         ...

@@ -126,8 +116,6 @@ class WanAudioPreInfer(WanPreInfer):
             grid_sizes=grid_sizes,
             x=x,
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )
lightx2v/models/networks/wan/infer/module_io.py (+0, -2)

@@ -16,8 +16,6 @@ class WanPreInferModuleOutput:
     grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
-    seq_lens: torch.Tensor
-    freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
     conditional_dict: Dict[str, Any] = field(default_factory=dict)
lightx2v/models/networks/wan/infer/pre_infer.py (+2, -13)

@@ -3,26 +3,17 @@ import torch
 from lightx2v.utils.envs import *
 from .module_io import GridOutput, WanPreInferModuleOutput
-from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
+from .utils import guidance_scale_embedding, sinusoidal_embedding_1d


 class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        d = config["dim"] // config["num_heads"]
         self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.device = torch.device(self.config.get("run_device", "cuda"))
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)
         ...

@@ -71,7 +62,7 @@ class WanPreInfer:
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:
         ...

@@ -130,8 +121,6 @@ class WanPreInfer:
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"motion_vec": motion_vec},
         )
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py (+15, -2)

+from dataclasses import dataclass
+
 import torch

-from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
+from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 ...

@@ -24,6 +26,17 @@ def rope_params(max_seq_len, dim, theta=10000):
     return freqs


+@dataclass
+class WanSFPreInferModuleOutput:
+    embed: torch.Tensor
+    grid_sizes: GridOutput
+    x: torch.Tensor
+    embed0: torch.Tensor
+    seq_lens: torch.Tensor
+    freqs: torch.Tensor
+    context: torch.Tensor
+
+
 class WanSFPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
         ...

@@ -87,7 +100,7 @@ class WanSFPreInfer(WanPreInfer):
         grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

-        return WanPreInferModuleOutput(
+        return WanSFPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
             ...
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py (+3, -0)

@@ -50,6 +50,9 @@ class WanSFTransformerInfer(WanTransformerInfer):
         self.infer_func = self.infer_with_kvcache

+    def get_scheduler_values(self):
+        pass
+
     def _initialize_kv_cache(self, dtype, device):
         """
         Initialize a Per-GPU KV cache for the Wan model.
         ...
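The no-op get_scheduler_values override lets the self-forcing transformer keep its existing freqs-based rotary handling (its pre-infer output above still carries seq_lens and freqs) while the base class now pulls a precomputed cos/sin cache from the scheduler at the start of infer(). A minimal sketch of that hook pattern, with hypothetical classes rather than the repo's:

```python
# Minimal sketch of the hook added here: the base infer() now calls
# get_scheduler_values() first, and a subclass can no-op it to opt out of the
# scheduler-provided cos/sin cache (hypothetical classes, not LightX2V's).
class Base:
    def get_scheduler_values(self):
        self.cos_sin = "cos/sin cache from scheduler"

    def infer(self):
        self.get_scheduler_values()
        return getattr(self, "cos_sin", None)


class SelfForcing(Base):
    def get_scheduler_values(self):
        pass  # the self-forcing path keeps its own per-block rope handling


print(Base().infer())         # cos/sin cache from scheduler
print(SelfForcing().infer())  # None
```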
lightx2v/models/networks/wan/infer/transformer_infer.py (+27, -28)

@@ -5,7 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
-from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_dist
+from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


 class WanTransformerInfer(BaseTransformerInfer):
 ...

@@ -20,11 +20,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
-        if config.get("rotary_chunk", False):
-            chunk_size = config.get("rotary_chunk_size", 100)
-            self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
-        else:
-            self.apply_rotary_emb_func = apply_rotary_emb
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_flashinfer
+        else:
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_torch)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_torch
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()
 ...

@@ -35,21 +40,20 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.seq_p_group = None
         self.infer_func = self.infer_without_offload
+        self.cos_sin = None

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if self.config["seq_parallel"]:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
+    def get_scheduler_values(self):
+        self.cos_sin = self.scheduler.cos_sin

     @torch.no_grad()
     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)
 ...

@@ -97,10 +101,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             )
             y_out = self.infer_self_attn(
                 block.compute_phases[0],
-                pre_infer_out.grid_sizes.tuple,
                 x,
-                pre_infer_out.seq_lens,
-                pre_infer_out.freqs,
                 shift_msa,
                 scale_msa,
             )
 ...

@@ -129,7 +130,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, phase, x, shift_msa, scale_msa):
+        cos_sin = self.cos_sin
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
 ...

@@ -153,16 +155,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = phase.self_attn_v.apply(norm1_out).view(s, n, d)
-        freqs_i = self.compute_freqs(q, grid_sizes, freqs)
-        q = self.apply_rotary_emb_func(q, freqs_i)
-        k = self.apply_rotary_emb_func(k, freqs_i)
-        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
+        q, k = self.apply_rope_func(q, k, cos_sin)
+        img_qkv_len = q.shape[0]
+        cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del freqs_i, norm1_out, norm1_weight, norm1_bias
+            del norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
 ...

@@ -170,8 +169,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                img_qkv_len=q.shape[0],
-                cu_seqlens_qkv=cu_seqlens_q,
+                img_qkv_len=img_qkv_len,
+                cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],
 ...

@@ -181,10 +180,10 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=q.size(0),
-                max_seqlen_kv=k.size(0),
+                cu_seqlens_q=cu_seqlens_qkv,
+                cu_seqlens_kv=cu_seqlens_qkv,
+                max_seqlen_q=img_qkv_len,
+                max_seqlen_kv=img_qkv_len,
                 model_cls=self.config["model_cls"],
             )
 ...
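Beyond swapping the rope helpers, this file replaces the cumsum-based _calculate_q_k_len result with a directly constructed cu_seqlens_qkv. For the single-sequence case handled here the two are the same; a quick standalone check with toy shapes in plain PyTorch:

```python
import torch

# Toy shapes: (seq_len, heads, head_dim). The old path prepended a zero and
# cumsummed a one-element length tensor; the new path writes [0, seq_len].
q = torch.randn(7, 12, 128)
q_lens = torch.tensor([q.size(0)], dtype=torch.int32)
old_cu = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
new_cu = torch.tensor([0, q.shape[0]], dtype=torch.int32)
assert torch.equal(old_cu, new_cu)  # both are tensor([0, 7], dtype=torch.int32)
```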
lightx2v/models/networks/wan/infer/utils.py (+79, -0)

 import torch
 import torch.distributed as dist
+from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

 from lightx2v.utils.envs import *


+def apply_wan_rope_with_torch(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    n = xq.size(1)
+    seq_len = cos_sin_cache.size(0)
+    xq = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    xk = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    # Apply rotary embedding
+    xq = torch.view_as_real(xq * cos_sin_cache).flatten(2)
+    xk = torch.view_as_real(xk * cos_sin_cache).flatten(2)
+    xq = torch.cat([xq, xq[seq_len:]])
+    xk = torch.cat([xk, xk[seq_len:]])
+    return xq.to(GET_DTYPE()), xk.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_chunk(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    chunk_size: int,
+    rope_func,
+):
+    seq_len = cos_sin_cache.size(0)
+    xq_output_chunks = []
+    xk_output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        xq_chunk = xq[start:end]
+        xk_chunk = xk[start:end]
+        cos_sin_chunk = cos_sin_cache[start:end]
+        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        xq_output_chunks.append(xq_chunk)
+        xk_output_chunks.append(xk_chunk)
+        torch.cuda.empty_cache()
+    x_q = torch.cat(xq_output_chunks, dim=0)
+    del xq_output_chunks
+    torch.cuda.empty_cache()
+    x_k = torch.cat(xk_output_chunks, dim=0)
+    del xk_output_chunks
+    torch.cuda.empty_cache()
+    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
+
+
+def apply_wan_rope_with_flashinfer(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    L, H, D = xq.shape
+    query = xq.reshape(L, H * D).contiguous()
+    key = xk.reshape(L, H * D).contiguous()
+    positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True)
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=D,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=False,
+    )
+    xq_out = query.view(L, H, D)
+    xk_out = key.view(L, H, D)
+    return xq_out, xk_out
+
+
 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes
     ...
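apply_wan_rope_with_chunk trades a loop for peak memory: it feeds position-aligned slices of q, k, and the cos/sin cache to the underlying rope kernel, concatenates the results, and calls torch.cuda.empty_cache() between slices to keep the allocator footprint small. Since rotary embedding acts independently per position, the chunked and unchunked results agree; a minimal sketch with a toy position-wise transform (not the repo's kernels) illustrates the invariant:

```python
import torch

# Stand-in for the real rope kernels: any transform that only combines a token
# with its own cache entry behaves the same whether applied whole or in chunks.
def toy_rope(xq, xk, cache):
    scale = cache.view(-1, 1, 1)  # one scalar per position
    return xq * scale, xk * scale

def chunked(xq, xk, cache, chunk_size, rope_func):
    out_q, out_k = [], []
    for start in range(0, cache.size(0), chunk_size):
        end = min(start + chunk_size, cache.size(0))
        oq, ok = rope_func(xq[start:end], xk[start:end], cache[start:end])
        out_q.append(oq)
        out_k.append(ok)
    return torch.cat(out_q), torch.cat(out_k)

q, k = torch.randn(10, 4, 8), torch.randn(10, 4, 8)
cache = torch.arange(10, dtype=torch.float32)
full_q, full_k = toy_rope(q, k, cache)
chunk_q, chunk_k = chunked(q, k, cache, chunk_size=3, rope_func=toy_rope)
assert torch.allclose(full_q, chunk_q) and torch.allclose(full_k, chunk_k)
```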
lightx2v/models/schedulers/wan/audio/scheduler.py (+5, -0)

@@ -12,6 +12,8 @@ from lightx2v.utils.utils import masks_like
 class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
+        d = config["dim"] // config["num_heads"]
+        self.rope_t_dim = d // 2 - 2 * (d // 6)
         if self.config["parallel"]:
             self.sp_size = self.config["parallel"].get("seq_p_size", 1)
         ...

@@ -83,6 +85,9 @@ class EulerScheduler(WanScheduler):
         self.timesteps = self.sigmas * self.num_train_timesteps

+        self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
     def step_post(self):
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)
         ...
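These additions move the reference-frame handling that audio/pre_infer.py used to apply to self.freqs into the scheduler: rows of the frequency table beyond the video frames have their temporal columns zeroed before the cos/sin cache is built. The rope_t_dim value is just the number of complex columns occupied by the temporal block of the 3D rope split; a small arithmetic check (a head dim of 128 is an assumed example, not taken from a specific config):

```python
# Wan-style 3D rope splits the head dim d into a temporal block of size
# d - 4*(d // 6) and two spatial blocks of size 2*(d // 6). Each complex
# column covers two channels, so the temporal block spans the first
# d // 2 - 2*(d // 6) columns, which is the rope_t_dim zeroed above.
d = 128  # assumed head dim, for illustration only
t_dim, s_dim = d - 4 * (d // 6), 2 * (d // 6)
assert t_dim + 2 * s_dim == d
assert t_dim // 2 == d // 2 - 2 * (d // 6)  # rope_t_dim
```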
lightx2v/models/schedulers/wan/scheduler.py (+45, -0)

@@ -14,6 +14,8 @@ class WanScheduler(BaseScheduler):
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
+        self.run_device = self.config.get("run_device", "cuda")
+        self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []
 ...

@@ -21,6 +23,24 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.sample_guide_scale = self.config["sample_guide_scale"]
         self.caching_records_2 = [True] * self.config["infer_steps"]
+        self.head_size = self.config["dim"] // self.config["num_heads"]
+        self.freqs = torch.cat(
+            [
+                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+            ],
+            dim=1,
+        ).to(torch.device(self.run_device))
+
+    def rope_params(self, max_seq_len, dim, theta=10000):
+        assert dim % 2 == 0
+        freqs = torch.outer(
+            torch.arange(max_seq_len),
+            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
+        )
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs

     def prepare(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:
         ...

@@ -47,6 +67,31 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))
+
+    def prepare_cos_sin(self, grid_sizes):
+        c = self.head_size // 2
+        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes
+        seq_len = f * h * w
+        cos_sin = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        )
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            cos_sin = cos_sin.reshape(seq_len, -1)
+            # Extract cos and sin parts separately and concatenate
+            cos_half = cos_sin.real.contiguous()
+            sin_half = cos_sin.imag.contiguous()
+            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
+        else:
+            cos_sin = cos_sin.reshape(seq_len, 1, -1)
+        return cos_sin

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
         self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(
         ...
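With the frequency table and prepare_cos_sin owned by the scheduler, the rope tables are built once per prepare() call and the transformer simply reads scheduler.cos_sin. A scaled-down walkthrough of the same construction (toy grid and head size of my choosing, not a Wan config) shows how per-axis complex frequencies become the flat [cos | sin] cache layout the flashinfer path expects:

```python
import torch

def rope_params(max_seq_len, dim, theta=10000):
    # Complex cis table, one column per channel pair (same form as the diff above).
    freqs = torch.outer(torch.arange(max_seq_len), 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
    return torch.polar(torch.ones_like(freqs), freqs)

head_size, f, h, w = 8, 2, 2, 2  # toy sizes
c = head_size // 2
freqs = torch.cat(
    [
        rope_params(16, head_size - 4 * (head_size // 6)),  # temporal
        rope_params(16, 2 * (head_size // 6)),              # height
        rope_params(16, 2 * (head_size // 6)),              # width
    ],
    dim=1,
)
parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
cos_sin = torch.cat(
    [
        parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
        parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
        parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
    ],
    dim=-1,
).reshape(f * h * w, -1)
# flashinfer layout: one row per token, cos half followed by sin half.
cache = torch.cat([cos_sin.real.contiguous(), cos_sin.imag.contiguous()], dim=-1)
print(cache.shape)  # torch.Size([8, 8])
```

For the torch backend the same tensor is instead kept complex and reshaped to (seq_len, 1, -1), matching the complex multiply in apply_wan_rope_with_torch.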