Commit 47b3ce2f (unverified)
Update wan infer rope (#518)
Authored by Yang Yong (雍洋) on Nov 27, 2025; committed by GitHub on Nov 27, 2025.
Parent: 5f277e80
Showing 10 changed files with 179 additions and 62 deletions.
lightx2v/models/networks/wan/infer/audio/post_infer.py                 +0   -2
lightx2v/models/networks/wan/infer/audio/pre_infer.py                  +3   -15
lightx2v/models/networks/wan/infer/module_io.py                        +0   -2
lightx2v/models/networks/wan/infer/pre_infer.py                        +2   -13
lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py           +15  -2
lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py   +3   -0
lightx2v/models/networks/wan/infer/transformer_infer.py                +27  -28
lightx2v/models/networks/wan/infer/utils.py                            +79  -0
lightx2v/models/schedulers/wan/audio/scheduler.py                      +5   -0
lightx2v/models/schedulers/wan/scheduler.py                            +45  -0
lightx2v/models/networks/wan/infer/audio/post_infer.py

@@ -10,8 +10,6 @@ class WanAudioPostInfer(WanPostInfer):
     @torch.no_grad()
     def infer(self, x, pre_infer_out):
-        x = x[: pre_infer_out.seq_lens[0]]
         t, h, w = pre_infer_out.grid_sizes.tuple
         grid_sizes = (t - 1, h, w)

lightx2v/models/networks/wan/infer/audio/pre_infer.py

@@ -4,27 +4,17 @@ from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *
 from ..module_io import GridOutput, WanPreInferModuleOutput
-from ..utils import rope_params, sinusoidal_embedding_1d
+from ..utils import sinusoidal_embedding_1d


 class WanAudioPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
-        d = config["dim"] // config["num_heads"]
         self.config = config
         self.task = config["task"]
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
-        self.rope_t_dim = d // 2 - 2 * (d // 6)
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

@@ -65,14 +55,14 @@ class WanAudioPreInfer(WanPreInfer):
         x = weights.patch_embedding.apply(x.unsqueeze(0))
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         y = weights.patch_embedding.apply(y.unsqueeze(0))
         y = y.flatten(2).transpose(1, 2).contiguous()
         x = torch.cat([x, y], dim=1).squeeze(0)
         ####for r2v # zero temporl component corresponding to ref embeddings
-        self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
+        # self.freqs[grid_sizes_t:, : self.rope_t_dim] = 0
         grid_sizes_t += 1
         person_mask_latens = inputs["person_mask_latens"]

@@ -126,8 +116,6 @@ class WanAudioPreInfer(WanPreInfer):
             grid_sizes=grid_sizes,
             x=x,
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"audio_encoder_output": inputs["audio_encoder_output"], "person_mask_latens": person_mask_latens},
         )

lightx2v/models/networks/wan/infer/module_io.py

@@ -16,8 +16,6 @@ class WanPreInferModuleOutput:
     grid_sizes: GridOutput
     x: torch.Tensor
     embed0: torch.Tensor
-    seq_lens: torch.Tensor
-    freqs: torch.Tensor
     context: torch.Tensor
     adapter_args: Dict[str, Any] = field(default_factory=dict)
     conditional_dict: Dict[str, Any] = field(default_factory=dict)

lightx2v/models/networks/wan/infer/pre_infer.py

@@ -3,26 +3,17 @@ import torch
 from lightx2v.utils.envs import *
 from .module_io import GridOutput, WanPreInferModuleOutput
-from .utils import guidance_scale_embedding, rope_params, sinusoidal_embedding_1d
+from .utils import guidance_scale_embedding, sinusoidal_embedding_1d


 class WanPreInfer:
     def __init__(self, config):
         assert (config["dim"] % config["num_heads"]) == 0 and (config["dim"] // config["num_heads"]) % 2 == 0
         self.config = config
-        d = config["dim"] // config["num_heads"]
         self.run_device = self.config.get("run_device", "cuda")
         self.clean_cuda_cache = config.get("clean_cuda_cache", False)
         self.task = config["task"]
         self.device = torch.device(self.config.get("run_device", "cuda"))
-        self.freqs = torch.cat(
-            [
-                rope_params(1024, d - 4 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-                rope_params(1024, 2 * (d // 6)),
-            ],
-            dim=1,
-        ).to(torch.device(self.run_device))
         self.freq_dim = config["freq_dim"]
         self.dim = config["dim"]
         self.enable_dynamic_cfg = config.get("enable_dynamic_cfg", False)

@@ -71,7 +62,7 @@ class WanPreInfer:
         grid_sizes_t, grid_sizes_h, grid_sizes_w = x.shape[2:]
         x = x.flatten(2).transpose(1, 2).contiguous()
-        seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
+        # seq_lens = torch.tensor(x.size(1), dtype=torch.int32, device=x.device).unsqueeze(0)
         embed = sinusoidal_embedding_1d(self.freq_dim, t.flatten())
         if self.enable_dynamic_cfg:

@@ -130,8 +121,6 @@ class WanPreInfer:
             grid_sizes=grid_sizes,
             x=x.squeeze(0),
             embed0=embed0.squeeze(0),
-            seq_lens=seq_lens,
-            freqs=self.freqs,
             context=context,
             adapter_args={"motion_vec": motion_vec},
         )

lightx2v/models/networks/wan/infer/self_forcing/pre_infer.py

+from dataclasses import dataclass
+
 import torch

-from lightx2v.models.networks.wan.infer.module_io import GridOutput, WanPreInferModuleOutput
+from lightx2v.models.networks.wan.infer.module_io import GridOutput
 from lightx2v.models.networks.wan.infer.pre_infer import WanPreInfer
 from lightx2v.utils.envs import *

@@ -24,6 +26,17 @@ def rope_params(max_seq_len, dim, theta=10000):
     return freqs


+@dataclass
+class WanSFPreInferModuleOutput:
+    embed: torch.Tensor
+    grid_sizes: GridOutput
+    x: torch.Tensor
+    embed0: torch.Tensor
+    seq_lens: torch.Tensor
+    freqs: torch.Tensor
+    context: torch.Tensor
+
+
 class WanSFPreInfer(WanPreInfer):
     def __init__(self, config):
         super().__init__(config)

@@ -87,7 +100,7 @@ class WanSFPreInfer(WanPreInfer):
         grid_sizes = GridOutput(tensor=torch.tensor([[grid_sizes_t, grid_sizes_h, grid_sizes_w]], dtype=torch.int32, device=x.device), tuple=(grid_sizes_t, grid_sizes_h, grid_sizes_w))

-        return WanPreInferModuleOutput(
+        return WanSFPreInferModuleOutput(
             embed=embed,
             grid_sizes=grid_sizes,
             x=x.squeeze(0),

lightx2v/models/networks/wan/infer/self_forcing/transformer_infer.py

@@ -50,6 +50,9 @@ class WanSFTransformerInfer(WanTransformerInfer):
         self.infer_func = self.infer_with_kvcache

+    def get_scheduler_values(self):
+        pass
+
     def _initialize_kv_cache(self, dtype, device):
         """
         Initialize a Per-GPU KV cache for the Wan model.

lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -5,7 +5,7 @@ import torch
 from lightx2v.common.transformer_infer.transformer_infer import BaseTransformerInfer
 from lightx2v.utils.envs import *
-from .utils import apply_rotary_emb, apply_rotary_emb_chunk, compute_freqs, compute_freqs_dist
+from .utils import apply_wan_rope_with_chunk, apply_wan_rope_with_flashinfer, apply_wan_rope_with_torch


 class WanTransformerInfer(BaseTransformerInfer):

@@ -20,11 +20,16 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.head_dim = config["dim"] // config["num_heads"]
         self.window_size = config.get("window_size", (-1, -1))
         self.parallel_attention = None
-        if config.get("rotary_chunk", False):
-            chunk_size = config.get("rotary_chunk_size", 100)
-            self.apply_rotary_emb_func = partial(apply_rotary_emb_chunk, chunk_size=chunk_size)
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_flashinfer)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_flashinfer
         else:
-            self.apply_rotary_emb_func = apply_rotary_emb
+            if self.config.get("rope_chunk", False):
+                self.apply_rope_func = partial(apply_wan_rope_with_chunk, chunk_size=self.config.get("rope_chunk_size", 100), rope_func=apply_wan_rope_with_torch)
+            else:
+                self.apply_rope_func = apply_wan_rope_with_torch
         self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

@@ -35,21 +40,20 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.seq_p_group = None
         self.infer_func = self.infer_without_offload
+        self.cos_sin = None

     def _calculate_q_k_len(self, q, k_lens):
         q_lens = torch.tensor([q.size(0)], dtype=torch.int32, device=q.device)
         cu_seqlens_q = torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(0, dtype=torch.int32)
         cu_seqlens_k = torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(0, dtype=torch.int32)
         return cu_seqlens_q, cu_seqlens_k

-    def compute_freqs(self, q, grid_sizes, freqs):
-        if self.config["seq_parallel"]:
-            freqs_i = compute_freqs_dist(q.size(0), q.size(2) // 2, grid_sizes, freqs, self.seq_p_group)
-        else:
-            freqs_i = compute_freqs(q.size(2) // 2, grid_sizes, freqs)
-        return freqs_i
+    def get_scheduler_values(self):
+        self.cos_sin = self.scheduler.cos_sin

     @torch.no_grad()
     def infer(self, weights, pre_infer_out):
+        self.get_scheduler_values()
         x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)

@@ -97,10 +101,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         )
         y_out = self.infer_self_attn(
             block.compute_phases[0],
-            pre_infer_out.grid_sizes.tuple,
             x,
-            pre_infer_out.seq_lens,
-            pre_infer_out.freqs,
             shift_msa,
             scale_msa,
         )

@@ -129,7 +130,8 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+    def infer_self_attn(self, phase, x, shift_msa, scale_msa):
+        cos_sin = self.cos_sin
         if hasattr(phase, "smooth_norm1_weight"):
             norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
             norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor

@@ -153,16 +155,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
         v = phase.self_attn_v.apply(norm1_out).view(s, n, d)

-        freqs_i = self.compute_freqs(q, grid_sizes, freqs)
-        q = self.apply_rotary_emb_func(q, freqs_i)
-        k = self.apply_rotary_emb_func(k, freqs_i)
+        q, k = self.apply_rope_func(q, k, cos_sin)

-        k_lens = torch.empty_like(seq_lens).fill_(freqs_i.size(0))
-        cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(q, k_lens=k_lens)
+        img_qkv_len = q.shape[0]
+        cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del freqs_i, norm1_out, norm1_weight, norm1_bias
+            del norm1_out, norm1_weight, norm1_bias
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:

@@ -170,8 +169,8 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                img_qkv_len=q.shape[0],
-                cu_seqlens_qkv=cu_seqlens_q,
+                img_qkv_len=img_qkv_len,
+                cu_seqlens_qkv=cu_seqlens_qkv,
                 attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],

@@ -181,10 +180,10 @@ class WanTransformerInfer(BaseTransformerInfer):
                 q=q,
                 k=k,
                 v=v,
-                cu_seqlens_q=cu_seqlens_q,
-                cu_seqlens_kv=cu_seqlens_k,
-                max_seqlen_q=q.size(0),
-                max_seqlen_kv=k.size(0),
+                cu_seqlens_q=cu_seqlens_qkv,
+                cu_seqlens_kv=cu_seqlens_qkv,
+                max_seqlen_q=img_qkv_len,
+                max_seqlen_kv=img_qkv_len,
                 model_cls=self.config["model_cls"],
             )

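The dispatch above hangs on three config keys: rope_type (default "flashinfer"), rope_chunk, and rope_chunk_size. A minimal, self-contained sketch of the same selection logic, with placeholder backends standing in for the real apply_wan_rope_with_* functions (only the branching mirrors the diff; nothing below is the repo's actual implementation):

# Sketch of the rope-function selection introduced in this commit.
# rope_torch / rope_flashinfer are stand-in stubs, not real backends.
from functools import partial

def rope_torch(q, k, cos_sin):
    return q, k  # placeholder

def rope_flashinfer(q, k, cos_sin):
    return q, k  # placeholder

def chunked(q, k, cos_sin, chunk_size, rope_func):
    return rope_func(q, k, cos_sin)  # placeholder; real version loops over seq chunks

def select_rope_func(config):
    backend = rope_flashinfer if config.get("rope_type", "flashinfer") == "flashinfer" else rope_torch
    if config.get("rope_chunk", False):
        # Chunking bounds peak memory on long sequences at the cost of extra launches.
        return partial(chunked, chunk_size=config.get("rope_chunk_size", 100), rope_func=backend)
    return backend

apply_rope = select_rope_func({"rope_type": "torch", "rope_chunk": True, "rope_chunk_size": 64})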
lightx2v/models/networks/wan/infer/utils.py

 import torch
 import torch.distributed as dist
+from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

 from lightx2v.utils.envs import *


+def apply_wan_rope_with_torch(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    n = xq.size(1)
+    seq_len = cos_sin_cache.size(0)
+    xq = torch.view_as_complex(xq[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    xk = torch.view_as_complex(xk[:seq_len].to(torch.float32).reshape(seq_len, n, -1, 2))
+    # Apply rotary embedding
+    xq = torch.view_as_real(xq * cos_sin_cache).flatten(2)
+    xk = torch.view_as_real(xk * cos_sin_cache).flatten(2)
+    xq = torch.cat([xq, xq[seq_len:]])
+    xk = torch.cat([xk, xk[seq_len:]])
+    return xq.to(GET_DTYPE()), xk.to(GET_DTYPE())


+def apply_wan_rope_with_chunk(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+    chunk_size: int,
+    rope_func,
+):
+    seq_len = cos_sin_cache.size(0)
+    xq_output_chunks = []
+    xk_output_chunks = []
+    for start in range(0, seq_len, chunk_size):
+        end = min(start + chunk_size, seq_len)
+        xq_chunk = xq[start:end]
+        xk_chunk = xk[start:end]
+        cos_sin_chunk = cos_sin_cache[start:end]
+        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        xq_output_chunks.append(xq_chunk)
+        xk_output_chunks.append(xk_chunk)
+        torch.cuda.empty_cache()
+    x_q = torch.cat(xq_output_chunks, dim=0)
+    del xq_output_chunks
+    torch.cuda.empty_cache()
+    x_k = torch.cat(xk_output_chunks, dim=0)
+    del xk_output_chunks
+    torch.cuda.empty_cache()
+    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())


+def apply_wan_rope_with_flashinfer(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    cos_sin_cache: torch.Tensor,
+):
+    L, H, D = xq.shape
+    query = xq.reshape(L, H * D).contiguous()
+    key = xk.reshape(L, H * D).contiguous()
+    positions = torch.arange(L, device="cpu", dtype=torch.long).to(xq.device, non_blocking=True)
+    apply_rope_with_cos_sin_cache_inplace(
+        positions=positions,
+        query=query,
+        key=key,
+        head_size=D,
+        cos_sin_cache=cos_sin_cache,
+        is_neox=False,
+    )
+    xq_out = query.view(L, H, D)
+    xk_out = key.view(L, H, D)
+    return xq_out, xk_out


 def compute_freqs(c, grid_sizes, freqs):
     freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
     f, h, w = grid_sizes

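On the torch path, consecutive channel pairs are viewed as complex numbers and rotated by a unit-magnitude complex cache, so cos_sin_cache is expected to broadcast against a (seq_len, n_heads, head_dim // 2) complex view; the scheduler's non-flashinfer branch reshapes it to (seq_len, 1, -1) accordingly. A toy, self-contained run of the same core math with made-up dimensions (nothing here comes from the repo's configs):

import torch

# Illustrative sizes only.
seq_len, n_heads, head_dim = 8, 2, 16

# Unit-magnitude complex rotation cache: one factor per (position, channel pair),
# broadcast over heads, analogous to the scheduler's torch-path cos_sin.
angles = torch.outer(
    torch.arange(seq_len, dtype=torch.float32),
    1.0 / torch.pow(10000.0, torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim),
)
cos_sin_cache = torch.polar(torch.ones_like(angles), angles).unsqueeze(1)  # (8, 1, 8)

xq = torch.randn(seq_len, n_heads, head_dim)

# Same core math as apply_wan_rope_with_torch: view pairs as complex, rotate, view back.
xq_c = torch.view_as_complex(xq.to(torch.float32).reshape(seq_len, n_heads, -1, 2))
xq_rot = torch.view_as_real(xq_c * cos_sin_cache).flatten(2)
print(xq_rot.shape)  # torch.Size([8, 2, 16])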
lightx2v/models/schedulers/wan/audio/scheduler.py

@@ -12,6 +12,8 @@ from lightx2v.utils.utils import masks_like
 class EulerScheduler(WanScheduler):
     def __init__(self, config):
         super().__init__(config)
+        d = config["dim"] // config["num_heads"]
+        self.rope_t_dim = d // 2 - 2 * (d // 6)
         if self.config["parallel"]:
             self.sp_size = self.config["parallel"].get("seq_p_size", 1)

@@ -83,6 +85,9 @@ class EulerScheduler(WanScheduler):
         self.timesteps = self.sigmas * self.num_train_timesteps
+        self.freqs[latent_shape[1] // self.patch_size[0] :, : self.rope_t_dim] = 0
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0] + 1, latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))

     def step_post(self):
         model_output = self.noise_pred.to(torch.float32)
         sample = self.latents.to(torch.float32)

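The rope_t_dim added here indexes the temporal block of the concatenated freqs: rope_params(1024, dim) yields dim // 2 complex columns, the temporal block of width d - 4 * (d // 6) comes first in the dim=1 concatenation, and d // 2 - 2 * (d // 6) equals half that width. A quick arithmetic check with an assumed head dim (not a value read from the repo's configs):

# Illustrative check; d = 128 is an assumption for the example.
d = 128
temporal_real_width = d - 4 * (d // 6)  # real dims passed to the temporal rope_params -> 44
rope_t_dim = d // 2 - 2 * (d // 6)      # columns zeroed above -> 22
assert rope_t_dim == temporal_real_width // 2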
lightx2v/models/schedulers/wan/scheduler.py

@@ -14,6 +14,8 @@ class WanScheduler(BaseScheduler):
         self.infer_steps = self.config["infer_steps"]
         self.target_video_length = self.config["target_video_length"]
         self.sample_shift = self.config["sample_shift"]
+        self.run_device = self.config.get("run_device", "cuda")
+        self.patch_size = (1, 2, 2)
         self.shift = 1
         self.num_train_timesteps = 1000
         self.disable_corrector = []

@@ -21,6 +23,24 @@ class WanScheduler(BaseScheduler):
         self.noise_pred = None
         self.sample_guide_scale = self.config["sample_guide_scale"]
         self.caching_records_2 = [True] * self.config["infer_steps"]
+        self.head_size = self.config["dim"] // self.config["num_heads"]
+        self.freqs = torch.cat(
+            [
+                self.rope_params(1024, self.head_size - 4 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+                self.rope_params(1024, 2 * (self.head_size // 6)),
+            ],
+            dim=1,
+        ).to(torch.device(self.run_device))
+
+    def rope_params(self, max_seq_len, dim, theta=10000):
+        assert dim % 2 == 0
+        freqs = torch.outer(
+            torch.arange(max_seq_len),
+            1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
+        )
+        freqs = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs

     def prepare(self, seed, latent_shape, image_encoder_output=None):
         if self.config["model_cls"] == "wan2.2" and self.config["task"] in ["i2v", "s2v"]:

@@ -47,6 +67,31 @@ class WanScheduler(BaseScheduler):
         self.set_timesteps(self.infer_steps, device=self.device, shift=self.sample_shift)
+        self.cos_sin = self.prepare_cos_sin((latent_shape[1] // self.patch_size[0], latent_shape[2] // self.patch_size[1], latent_shape[3] // self.patch_size[2]))

+    def prepare_cos_sin(self, grid_sizes):
+        c = self.head_size // 2
+        freqs = self.freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
+        f, h, w = grid_sizes
+        seq_len = f * h * w
+        cos_sin = torch.cat(
+            [
+                freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
+                freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
+                freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
+            ],
+            dim=-1,
+        )
+        if self.config.get("rope_type", "flashinfer") == "flashinfer":
+            cos_sin = cos_sin.reshape(seq_len, -1)
+            # Extract cos and sin parts separately and concatenate
+            cos_half = cos_sin.real.contiguous()
+            sin_half = cos_sin.imag.contiguous()
+            cos_sin = torch.cat([cos_half, sin_half], dim=-1)
+        else:
+            cos_sin = cos_sin.reshape(seq_len, 1, -1)
+        return cos_sin

     def prepare_latents(self, seed, latent_shape, dtype=torch.float32):
         self.generator = torch.Generator(device=self.device).manual_seed(seed)
         self.latents = torch.randn(

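A toy run of the prepare_cos_sin expansion with made-up sizes, showing how the three per-axis frequency blocks are broadcast over an (f, h, w) latent grid and flattened into one row per token (all dimensions here are illustrative, not the model's real ones):

import torch

head_size = 12            # assumed; the real value is dim // num_heads
c = head_size // 2        # complex columns per token
f, h, w = 2, 3, 4
max_len = 16

def rope_params(max_seq_len, dim, theta=10000):
    freqs = torch.outer(
        torch.arange(max_seq_len),
        1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)),
    )
    return torch.polar(torch.ones_like(freqs), freqs)

freqs = torch.cat(
    [
        rope_params(max_len, head_size - 4 * (head_size // 6)),
        rope_params(max_len, 2 * (head_size // 6)),
        rope_params(max_len, 2 * (head_size // 6)),
    ],
    dim=1,
)

parts = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
cos_sin = torch.cat(
    [
        parts[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),  # temporal axis
        parts[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),  # height axis
        parts[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),  # width axis
    ],
    dim=-1,
)
print(cos_sin.reshape(f * h * w, -1).shape)  # torch.Size([24, 6]), complex dtype

On the flashinfer branch, this complex tensor is then split into its real (cos) and imaginary (sin) halves and concatenated along the last dim, matching the [cos | sin] cos_sin_cache layout consumed by apply_rope_with_cos_sin_cache_inplace.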