xuwx1 / LightX2V · Commits · abeb9bc8

Commit abeb9bc8, authored Aug 25, 2025 by gushiqiao, committed by GitHub on Aug 25, 2025

[Feat] Support vace offload and recon offload. (#245)

parent 87343386

Showing 7 changed files with 245 additions and 216 deletions (+245 -216)
lightx2v/common/offload/manager.py                                  +14   -8
lightx2v/common/ops/mm/mm_weight.py                                  +6   -0
lightx2v/models/networks/wan/infer/offload/transformer_infer.py    +155 -112
lightx2v/models/networks/wan/infer/transformer_infer.py             +51  -45
lightx2v/models/networks/wan/infer/vace/transformer_infer.py        +14  -22
lightx2v/models/networks/wan/weights/transformer_weights.py          +2  -25
lightx2v/models/networks/wan/weights/vace/transformer_weights.py     +3   -4
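The offload paths added in this commit are switched on by configuration keys that the diffs below read through self.config.get(...): cpu_offload, offload_ratio, offload_granularity, lazy_load, num_disk_workers and max_memory. As a rough, hypothetical sketch in Python (only the key names come from the diff; the surrounding config structure and the values are assumptions, not part of this commit), a phase-granularity offload run might be configured like this:

# Hypothetical config sketch; only the key names are taken from the diffs below.
offload_config = {
    "cpu_offload": True,             # enable CPU offload of transformer block weights
    "offload_ratio": 1.0,            # fraction of blocks whose weights get offloaded
    "offload_granularity": "phase",  # "block", "phase", or "model"
    "lazy_load": False,              # True selects the disk-backed LazyWeightAsyncStreamManager
    "num_disk_workers": 2,           # lazy path only: disk prefetch workers
    "max_memory": 2,                 # lazy path only: pinned-memory budget
}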
lightx2v/common/offload/manager.py

@@ -10,21 +10,27 @@ from loguru import logger
 class WeightAsyncStreamManager(object):
     def __init__(self, blocks_num, offload_ratio=1, phases_num=1):
-        self.active_weights = [None for _ in range(3)]
+        self.init(blocks_num, phases_num, offload_ratio)
         self.compute_stream = torch.cuda.Stream(priority=-1)
         self.cpu_load_stream = torch.cuda.Stream(priority=0)
         self.cuda_load_stream = torch.cuda.Stream(priority=0)
-        self.offload_block_num = int(offload_ratio * blocks_num)
-        self.block_nums = blocks_num
-        self.phases_num = phases_num
-        self.offload_phases_num = blocks_num * phases_num * offload_ratio
+
+    def init(self, blocks_num, phases_num, offload_ratio):
+        if hasattr(self, "active_weights"):
+            del self.active_weights[:]
+        self.active_weights = [None for _ in range(3)]
+        self.blocks_num = blocks_num
+        self.phases_num = phases_num
+        self.offload_ratio = offload_ratio
+        self.offload_blocks_num = int(self.offload_ratio * self.blocks_num)
+        self.offload_phases_num = self.blocks_num * self.phases_num * self.offload_ratio

     def prefetch_weights(self, block_idx, blocks_weights):
         with torch.cuda.stream(self.cuda_load_stream):
             self.active_weights[2] = blocks_weights[block_idx]
             self.active_weights[2].to_cuda_async()
         with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_block_num:
+            if block_idx < self.offload_blocks_num:
                 if self.active_weights[1] is not None:
                     self.active_weights[1].to_cpu_async()

@@ -130,7 +136,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
         if next_block_idx < 0:
             next_block_idx = 0
-        if next_block_idx == self.block_nums:
+        if next_block_idx == self.blocks_num:
             return
         if self.offload_gra == "phase":

@@ -175,7 +181,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
                 self.pin_memory_buffer.push(block_idx, block)
                 block_idx += 1
-                if block_idx == self.block_nums:
+                if block_idx == self.blocks_num:
                     break

     def prefetch_weights_from_disk(self, blocks):

@@ -217,7 +223,7 @@ class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
             self.active_weights[2] = (obj_key, block)
         with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_block_num:
+            if block_idx < self.offload_blocks_num:
                 if self.active_weights[1] is not None:
                     old_key, old_block = self.active_weights[1]
                     if self.pin_memory_buffer.exists(old_key):
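WeightAsyncStreamManager overlaps weight movement with compute by keeping three CUDA streams: a high-priority compute stream plus separate streams for host-to-device prefetch and device-to-host offload. The standalone sketch below (not LightX2V code; the toy tensors, sizes and the coarse per-iteration synchronize() calls are assumptions for illustration) shows the same double-buffering idea in plain PyTorch: while block i runs on the compute stream, block i + 1 is copied over on the load stream, and the buffers are swapped afterwards, much like swap_weights().

import torch

# Minimal double-buffered prefetch/compute loop (assumes a CUDA device is available).
blocks = [torch.randn(1024, 1024).pin_memory() for _ in range(8)]  # CPU-resident "block weights"
compute_stream = torch.cuda.Stream(priority=-1)
load_stream = torch.cuda.Stream(priority=0)

x = torch.randn(1024, 1024, device="cuda")
cur = blocks[0].cuda()        # first block is loaded synchronously
torch.cuda.synchronize()

for i in range(len(blocks)):
    nxt = None
    if i < len(blocks) - 1:
        with torch.cuda.stream(load_stream):
            nxt = blocks[i + 1].cuda(non_blocking=True)   # prefetch the next block on the copy stream
    with torch.cuda.stream(compute_stream):
        x = x @ cur                                       # "infer" with the current block
    compute_stream.synchronize()   # coarse per-iteration sync, for clarity only;
    load_stream.synchronize()      # the real manager relies on stream ordering instead
    if nxt is not None:
        cur = nxt                                         # swap buffers, like swap_weights()

print(x.norm().item())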
lightx2v/common/ops/mm/mm_weight.py

@@ -95,6 +95,12 @@ class MMWeight(MMWeightTemplate):
         self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
         self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype) if self.bias is not None else None

+    def _calculate_size(self):
+        if self.bias is not None:
+            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
+        return self.weight.numel() * self.weight.element_size()
+
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
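The new _calculate_size helper reports a weight's byte footprint (weight plus optional bias) as numel() times element_size(); this is presumably what the pinned-memory budgeting of the lazy offload path keys off. A small sketch of the same arithmetic (the shapes and dtype here are made up for illustration):

import torch

# Byte accounting in the style of _calculate_size: weight plus optional bias.
weight = torch.empty(5120, 5120, dtype=torch.bfloat16)
bias = torch.empty(5120, dtype=torch.bfloat16)

size_bytes = weight.numel() * weight.element_size()
if bias is not None:
    size_bytes += bias.numel() * bias.element_size()

print(f"{size_bytes / 1024**2:.2f} MiB")   # 5120*5120*2 B + 5120*2 B, roughly 50.01 MiB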
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

@@ -4,8 +4,7 @@ from lightx2v.common.offload.manager import (
     LazyWeightAsyncStreamManager,
     WeightAsyncStreamManager,
 )
-from lightx2v.models.networks.wan.infer.transformer_infer import WanTransformerInfer
+from ..transformer_infer import WanTransformerInfer


 class WanOffloadTransformerInfer(WanTransformerInfer):

@@ -13,20 +12,31 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         super().__init__(config)
         if self.config.get("cpu_offload", False):
             if "offload_ratio" in self.config:
-                offload_ratio = self.config["offload_ratio"]
+                self.offload_ratio = self.config["offload_ratio"]
             else:
-                offload_ratio = 1
+                self.offload_ratio = 1
             offload_granularity = self.config.get("offload_granularity", "block")
             if offload_granularity == "block":
                 if not self.config.get("lazy_load", False):
-                    self.infer_func = self.infer_with_offload
+                    self.infer_func = self.infer_with_blocks_offload
                 else:
-                    self.infer_func = self.infer_with_lazy_offload
+                    self.infer_func = self.infer_with_blocks_lazy_offload
             elif offload_granularity == "phase":
                 if not self.config.get("lazy_load", False):
                     self.infer_func = self.infer_with_phases_offload
                 else:
                     self.infer_func = self.infer_with_phases_lazy_offload
+                self.phase_params = {
+                    "shift_msa": None,
+                    "scale_msa": None,
+                    "gate_msa": None,
+                    "c_shift_msa": None,
+                    "c_scale_msa": None,
+                    "c_gate_msa": None,
+                    "y_out": None,
+                    "attn_out": None,
+                    "y": None,
+                }
             elif offload_granularity == "model":
                 self.infer_func = self.infer_without_offload

@@ -34,168 +44,201 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
             if not self.config.get("lazy_load", False):
                 self.weights_stream_mgr = WeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
+                    offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
                 )
             else:
                 self.weights_stream_mgr = LazyWeightAsyncStreamManager(
                     blocks_num=self.blocks_num,
-                    offload_ratio=offload_ratio,
+                    offload_ratio=self.offload_ratio,
                     phases_num=self.phases_num,
                     num_disk_workers=self.config.get("num_disk_workers", 2),
                     max_memory=self.config.get("max_memory", 2),
                     offload_gra=offload_granularity,
                 )

-    def infer_with_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(self.blocks_num):
+    def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
             if block_idx == 0:
-                self.weights_stream_mgr.active_weights[0] = weights.blocks[0]
+                self.weights_stream_mgr.active_weights[0] = blocks[0]
                 self.weights_stream_mgr.active_weights[0].to_cuda()
-            if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
+            if block_idx < len(blocks) - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
             self.weights_stream_mgr.swap_weights()
         return x

-    def infer_with_lazy_offload(self, weights, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
-        for block_idx in range(self.blocks_num):
+    def infer_with_blocks_lazy_offload(self, blocks, x, pre_infer_out):
+        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        for block_idx in range(len(blocks)):
             if block_idx == 0:
                 block = self.weights_stream_mgr.pin_memory_buffer.get(block_idx)
                 block.to_cuda()
                 self.weights_stream_mgr.active_weights[0] = (block_idx, block)
-            if block_idx < self.blocks_num - 1:
-                self.weights_stream_mgr.prefetch_weights(block_idx + 1, weights.blocks)
+            if block_idx < len(blocks) - 1:
+                self.weights_stream_mgr.prefetch_weights(block_idx + 1, blocks)
             with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+                x = self.infer_block(blocks[block_idx], x, pre_infer_out)
             self.weights_stream_mgr.swap_weights()
-            if block_idx == self.blocks_num - 1:
+            if block_idx == len(blocks) - 1:
                 self.weights_stream_mgr.pin_memory_buffer.pop_front()
-            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
+            self.weights_stream_mgr._async_prefetch_block(blocks)
         if self.clean_cuda_cache:
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
+            del (
+                pre_infer_out.grid_sizes,
+                pre_infer_out.embed0,
+                pre_infer_out.seq_lens,
+                pre_infer_out.freqs,
+                pre_infer_out.context,
+            )
             torch.cuda.empty_cache()
         return x

-    def infer_with_phases_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(weights.blocks_num):
+    def infer_with_phases_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            for phase_idx in range(self.phases_num):
-                if block_idx == 0 and phase_idx == 0:
-                    phase = weights.blocks[block_idx].compute_phases[phase_idx]
-                    phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
-                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                    cur_phase_idx, cur_phase = self.weights_stream_mgr.active_weights[0]
-                    if cur_phase_idx == 0:
-                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
-                    elif cur_phase_idx == 1:
-                        y_out = self.infer_self_attn(
-                            cur_phase,
-                            pre_infer_out.grid_sizes,
-                            x,
-                            pre_infer_out.seq_lens,
-                            pre_infer_out.freqs,
-                            shift_msa,
-                            scale_msa,
-                        )
-                    elif cur_phase_idx == 2:
-                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
-                    elif cur_phase_idx == 3:
-                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
-                is_last_phase = block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1
-                if not is_last_phase:
-                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                    next_phase_idx = (phase_idx + 1) % self.phases_num
-                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
-                self.weights_stream_mgr.swap_phases()
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, False)
             if self.clean_cuda_cache:
-                del attn_out, y_out, y
+                del (
+                    self.phase_params["attn_out"],
+                    self.phase_params["y_out"],
+                    self.phase_params["y"],
+                )
                 torch.cuda.empty_cache()
         if self.clean_cuda_cache:
-            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
-            torch.cuda.empty_cache()
+            self.clear_offload_params(pre_infer_out)
         return x

-    def infer_with_phases_lazy_offload(self, weights, x, pre_infer_out):
-        self.weights_stream_mgr.prefetch_weights_from_disk(weights.blocks)
-        for block_idx in range(weights.blocks_num):
+    def infer_with_phases_lazy_offload(self, blocks, x, pre_infer_out):
+        self.weights_stream_mgr.prefetch_weights_from_disk(blocks)
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            for phase_idx in range(self.weights_stream_mgr.phases_num):
-                if block_idx == 0 and phase_idx == 0:
-                    obj_key = (block_idx, phase_idx)
-                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
-                    phase.to_cuda()
-                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
-                with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
-                    (
-                        (
-                            _,
-                            cur_phase_idx,
-                        ),
-                        cur_phase,
-                    ) = self.weights_stream_mgr.active_weights[0]
-                    if cur_phase_idx == 0:
-                        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(cur_phase, pre_infer_out.embed0)
-                    elif cur_phase_idx == 1:
-                        y_out = self.infer_self_attn(
-                            cur_phase,
-                            pre_infer_out.grid_sizes,
-                            x,
-                            pre_infer_out.seq_lens,
-                            pre_infer_out.freqs,
-                            shift_msa,
-                            scale_msa,
-                        )
-                    elif cur_phase_idx == 2:
-                        x, attn_out = self.infer_cross_attn(cur_phase, x, pre_infer_out.context, y_out, gate_msa)
-                    elif cur_phase_idx == 3:
-                        y = self.infer_ffn(cur_phase, x, attn_out, c_shift_msa, c_scale_msa)
-                        x = self.post_process(x, y, c_gate_msa, pre_infer_out)
-                if not (block_idx == weights.blocks_num - 1 and phase_idx == self.phases_num - 1):
-                    next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
-                    next_phase_idx = (phase_idx + 1) % self.weights_stream_mgr.phases_num
-                    self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, weights.blocks)
-                self.weights_stream_mgr.swap_phases()
-            self.weights_stream_mgr._async_prefetch_block(weights.blocks)
+            x = self.infer_phases(block_idx, blocks, x, pre_infer_out, True)
+            self.weights_stream_mgr._async_prefetch_block(blocks)
             if self.clean_cuda_cache:
-                del attn_out, y_out, y
+                del (
+                    self.phase_params["attn_out"],
+                    self.phase_params["y_out"],
+                    self.phase_params["y"],
+                )
                 torch.cuda.empty_cache()
         if self.clean_cuda_cache:
-            del shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa
-            del pre_infer_out.grid_sizes, pre_infer_out.embed0, pre_infer_out.seq_lens, pre_infer_out.freqs, pre_infer_out.context
-            torch.cuda.empty_cache()
+            self.clear_offload_params(pre_infer_out)
         return x
+
+    def infer_phases(self, block_idx, blocks, x, pre_infer_out, lazy):
+        for phase_idx in range(self.phases_num):
+            if block_idx == 0 and phase_idx == 0:
+                if lazy:
+                    obj_key = (block_idx, phase_idx)
+                    phase = self.weights_stream_mgr.pin_memory_buffer.get(obj_key)
+                    phase.to_cuda()
+                    self.weights_stream_mgr.active_weights[0] = (obj_key, phase)
+                else:
+                    phase = blocks[block_idx].compute_phases[phase_idx]
+                    phase.to_cuda()
+                    self.weights_stream_mgr.active_weights[0] = (phase_idx, phase)
+            with torch.cuda.stream(self.weights_stream_mgr.compute_stream):
+                x = self.infer_phase(self.weights_stream_mgr.active_weights[0], x, pre_infer_out)
+            is_last_phase = block_idx == len(blocks) - 1 and phase_idx == self.phases_num - 1
+            if not is_last_phase:
+                next_block_idx = block_idx + 1 if phase_idx == self.phases_num - 1 else block_idx
+                next_phase_idx = (phase_idx + 1) % self.phases_num
+                self.weights_stream_mgr.prefetch_phase(next_block_idx, next_phase_idx, blocks)
+            self.weights_stream_mgr.swap_phases()
+        return x
+
+    def infer_phase(self, active_weight, x, pre_infer_out):
+        if not self.config.get("lazy_load"):
+            cur_phase_idx, cur_phase = active_weight
+        else:
+            (_, cur_phase_idx), cur_phase = active_weight
+        if cur_phase_idx == 0:
+            if hasattr(cur_phase, "before_proj"):
+                x = cur_phase.before_proj.apply(x) + pre_infer_out.x
+            (
+                self.phase_params["shift_msa"],
+                self.phase_params["scale_msa"],
+                self.phase_params["gate_msa"],
+                self.phase_params["c_shift_msa"],
+                self.phase_params["c_scale_msa"],
+                self.phase_params["c_gate_msa"],
+            ) = self.pre_process(cur_phase.modulation, pre_infer_out.embed0)
+            self.phase_params["y_out"] = self.infer_self_attn(
+                cur_phase,
+                pre_infer_out.grid_sizes,
+                x,
+                pre_infer_out.seq_lens,
+                pre_infer_out.freqs,
+                self.phase_params["shift_msa"],
+                self.phase_params["scale_msa"],
+            )
+        elif cur_phase_idx == 1:
+            x, self.phase_params["attn_out"] = self.infer_cross_attn(
+                cur_phase,
+                x,
+                pre_infer_out.context,
+                self.phase_params["y_out"],
+                self.phase_params["gate_msa"],
+            )
+        elif cur_phase_idx == 2:
+            self.phase_params["y"] = self.infer_ffn(
+                cur_phase,
+                x,
+                self.phase_params["attn_out"],
+                self.phase_params["c_shift_msa"],
+                self.phase_params["c_scale_msa"],
+            )
+            x = self.post_process(
+                x,
+                self.phase_params["y"],
+                self.phase_params["c_gate_msa"],
+                pre_infer_out,
+            )
+            if hasattr(cur_phase, "after_proj"):
+                pre_infer_out.adapter_output["hints"].append(cur_phase.after_proj.apply(x))
+        return x
+
+    def clear_offload_params(self, pre_infer_out):
+        del (
+            self.phase_params["shift_msa"],
+            self.phase_params["scale_msa"],
+            self.phase_params["gate_msa"],
+        )
+        del (
+            self.phase_params["c_shift_msa"],
+            self.phase_params["c_scale_msa"],
+            self.phase_params["c_gate_msa"],
+        )
+        del (
+            pre_infer_out.grid_sizes,
+            pre_infer_out.embed0,
+            pre_infer_out.seq_lens,
+            pre_infer_out.freqs,
+            pre_infer_out.context,
+        )
+        torch.cuda.empty_cache()
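The refactored infer_phases keeps the pipeline one step ahead: while phase p of block b runs on the compute stream, the prefetcher is asked for phase (p + 1) % phases_num, bumping to block b + 1 when the phase index wraps. The tiny sketch below just walks that schedule (the block and phase counts are made up; the rollover arithmetic is the same as in the diff above):

# Walk the prefetch schedule used by infer_phases.
phases_num = 3   # the new three-phase split: self-attention, cross-attention, FFN
blocks_num = 4   # hypothetical block count, for the printout only

for block_idx in range(blocks_num):
    for phase_idx in range(phases_num):
        is_last_phase = block_idx == blocks_num - 1 and phase_idx == phases_num - 1
        if is_last_phase:
            continue   # nothing left to prefetch
        next_block_idx = block_idx + 1 if phase_idx == phases_num - 1 else block_idx
        next_phase_idx = (phase_idx + 1) % phases_num
        print(f"compute ({block_idx}, {phase_idx}) -> prefetch ({next_block_idx}, {next_phase_idx})")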
lightx2v/models/networks/wan/infer/transformer_infer.py

@@ -14,7 +14,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         self.task = config.task
         self.attention_type = config.get("attention_type", "flash_attn2")
         self.blocks_num = config.num_layers
-        self.phases_num = 4
+        self.phases_num = 3
         self.num_heads = config.num_heads
         self.head_dim = config.dim // config.num_heads
         self.window_size = config.get("window_size", (-1, -1))

@@ -49,11 +49,11 @@ class WanTransformerInfer(BaseTransformerInfer):
         return freqs_i

     def infer(self, weights, pre_infer_out):
-        x = self.infer_main_blocks(weights, pre_infer_out)
+        x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)

-    def infer_main_blocks(self, weights, pre_infer_out):
-        x = self.infer_func(weights, pre_infer_out.x, pre_infer_out)
+    def infer_main_blocks(self, blocks, pre_infer_out):
+        x = self.infer_func(blocks, pre_infer_out.x, pre_infer_out)
         return x

     def infer_non_blocks(self, weights, x, e):

@@ -80,19 +80,22 @@ class WanTransformerInfer(BaseTransformerInfer):
             torch.cuda.empty_cache()
         return x

-    def infer_without_offload(self, weights, x, pre_infer_out):
-        for block_idx in range(self.blocks_num):
+    def infer_without_offload(self, blocks, x, pre_infer_out):
+        for block_idx in range(len(blocks)):
             self.block_idx = block_idx
-            x = self.infer_block(weights.blocks[block_idx], x, pre_infer_out)
+            x = self.infer_block(blocks[block_idx], x, pre_infer_out)
         return x

-    def infer_block(self, weights, x, pre_infer_out):
-        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.infer_modulation(
-            weights.compute_phases[0],
+    def infer_block(self, block, x, pre_infer_out):
+        if hasattr(block.compute_phases[0], "before_proj"):
+            x = block.compute_phases[0].before_proj.apply(x) + pre_infer_out.x
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = self.pre_process(
+            block.compute_phases[0].modulation,
             pre_infer_out.embed0,
         )
         y_out = self.infer_self_attn(
-            weights.compute_phases[1],
+            block.compute_phases[0],
             pre_infer_out.grid_sizes,
             x,
             pre_infer_out.seq_lens,

@@ -100,18 +103,21 @@ class WanTransformerInfer(BaseTransformerInfer):
             shift_msa,
             scale_msa,
         )
-        x, attn_out = self.infer_cross_attn(weights.compute_phases[2], x, pre_infer_out.context, y_out, gate_msa)
-        y = self.infer_ffn(weights.compute_phases[3], x, attn_out, c_shift_msa, c_scale_msa)
+        x, attn_out = self.infer_cross_attn(block.compute_phases[1], x, pre_infer_out.context, y_out, gate_msa)
+        y = self.infer_ffn(block.compute_phases[2], x, attn_out, c_shift_msa, c_scale_msa)
         x = self.post_process(x, y, c_gate_msa, pre_infer_out)
+        if hasattr(block.compute_phases[2], "after_proj"):
+            pre_infer_out.adapter_output["hints"].append(block.compute_phases[2].after_proj.apply(x))
         return x

-    def infer_modulation(self, weights, embed0):
+    def pre_process(self, modulation, embed0):
         if embed0.dim() == 3 and embed0.shape[2] == 1:
-            modulation = weights.modulation.tensor.unsqueeze(2)
+            modulation = modulation.tensor.unsqueeze(2)
             embed0 = (modulation + embed0).chunk(6, dim=1)
             shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = [ei.squeeze(1) for ei in embed0]
         else:
-            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (weights.modulation.tensor + embed0).chunk(6, dim=1)
+            shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation.tensor + embed0).chunk(6, dim=1)
         if self.clean_cuda_cache:
             del embed0

@@ -119,15 +125,15 @@ class WanTransformerInfer(BaseTransformerInfer):
         return shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa

-    def infer_self_attn(self, weights, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
-        if hasattr(weights, "smooth_norm1_weight"):
-            norm1_weight = (1 + scale_msa.squeeze()) * weights.smooth_norm1_weight.tensor
-            norm1_bias = shift_msa.squeeze() * weights.smooth_norm1_bias.tensor
+    def infer_self_attn(self, phase, grid_sizes, x, seq_lens, freqs, shift_msa, scale_msa):
+        if hasattr(phase, "smooth_norm1_weight"):
+            norm1_weight = (1 + scale_msa.squeeze()) * phase.smooth_norm1_weight.tensor
+            norm1_bias = shift_msa.squeeze() * phase.smooth_norm1_bias.tensor
         else:
             norm1_weight = 1 + scale_msa.squeeze()
             norm1_bias = shift_msa.squeeze()

-        norm1_out = weights.norm1.apply(x)
+        norm1_out = phase.norm1.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm1_out = norm1_out.to(self.sensitive_layer_dtype)

@@ -139,9 +145,9 @@ class WanTransformerInfer(BaseTransformerInfer):
         s, n, d = *norm1_out.shape[:1], self.num_heads, self.head_dim
-        q = weights.self_attn_norm_q.apply(weights.self_attn_q.apply(norm1_out)).view(s, n, d)
-        k = weights.self_attn_norm_k.apply(weights.self_attn_k.apply(norm1_out)).view(s, n, d)
-        v = weights.self_attn_v.apply(norm1_out).view(s, n, d)
+        q = phase.self_attn_norm_q.apply(phase.self_attn_q.apply(norm1_out)).view(s, n, d)
+        k = phase.self_attn_norm_k.apply(phase.self_attn_k.apply(norm1_out)).view(s, n, d)
+        v = phase.self_attn_v.apply(norm1_out).view(s, n, d)

         freqs_i = self.compute_freqs(q, grid_sizes, freqs)

@@ -156,18 +162,18 @@ class WanTransformerInfer(BaseTransformerInfer):
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
-            attn_out = weights.self_attn_1_parallel.apply(
+            attn_out = phase.self_attn_1_parallel.apply(
                 q=q,
                 k=k,
                 v=v,
                 img_qkv_len=q.shape[0],
                 cu_seqlens_qkv=cu_seqlens_q,
-                attention_module=weights.self_attn_1,
+                attention_module=phase.self_attn_1,
                 seq_p_group=self.seq_p_group,
                 model_cls=self.config["model_cls"],
             )
         else:
-            attn_out = weights.self_attn_1.apply(
+            attn_out = phase.self_attn_1.apply(
                 q=q,
                 k=k,
                 v=v,

@@ -179,7 +185,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             mask_map=self.mask_map,
         )
-        y = weights.self_attn_o.apply(attn_out)
+        y = phase.self_attn_o.apply(attn_out)

         if self.clean_cuda_cache:
             del q, k, v, attn_out

@@ -187,13 +193,13 @@ class WanTransformerInfer(BaseTransformerInfer):
         return y

-    def infer_cross_attn(self, weights, x, context, y_out, gate_msa):
+    def infer_cross_attn(self, phase, x, context, y_out, gate_msa):
         if self.sensitive_layer_dtype != self.infer_dtype:
             x = x.to(self.sensitive_layer_dtype) + y_out.to(self.sensitive_layer_dtype) * gate_msa.squeeze()
         else:
             x.add_(y_out * gate_msa.squeeze())

-        norm3_out = weights.norm3.apply(x)
+        norm3_out = phase.norm3.apply(x)
         if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True):
             context_img = context[:257]
             context = context[257:]

@@ -207,14 +213,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         n, d = self.num_heads, self.head_dim
-        q = weights.cross_attn_norm_q.apply(weights.cross_attn_q.apply(norm3_out)).view(-1, n, d)
-        k = weights.cross_attn_norm_k.apply(weights.cross_attn_k.apply(context)).view(-1, n, d)
-        v = weights.cross_attn_v.apply(context).view(-1, n, d)
+        q = phase.cross_attn_norm_q.apply(phase.cross_attn_q.apply(norm3_out)).view(-1, n, d)
+        k = phase.cross_attn_norm_k.apply(phase.cross_attn_k.apply(context)).view(-1, n, d)
+        v = phase.cross_attn_v.apply(context).view(-1, n, d)

         cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
             q,
             k_lens=torch.tensor([k.size(0)], dtype=torch.int32, device=k.device),
         )

-        attn_out = weights.cross_attn_1.apply(
+        attn_out = phase.cross_attn_1.apply(
             q=q,
             k=k,
             v=v,

@@ -226,14 +232,14 @@ class WanTransformerInfer(BaseTransformerInfer):
         )

         if self.task in ["i2v", "flf2v"] and self.config.get("use_image_encoder", True) and context_img is not None:
-            k_img = weights.cross_attn_norm_k_img.apply(weights.cross_attn_k_img.apply(context_img)).view(-1, n, d)
-            v_img = weights.cross_attn_v_img.apply(context_img).view(-1, n, d)
+            k_img = phase.cross_attn_norm_k_img.apply(phase.cross_attn_k_img.apply(context_img)).view(-1, n, d)
+            v_img = phase.cross_attn_v_img.apply(context_img).view(-1, n, d)
             cu_seqlens_q, cu_seqlens_k = self._calculate_q_k_len(
                 q,
                 k_lens=torch.tensor([k_img.size(0)], dtype=torch.int32, device=k.device),
             )
-            img_attn_out = weights.cross_attn_2.apply(
+            img_attn_out = phase.cross_attn_2.apply(
                 q=q,
                 k=k_img,
                 v=v_img,

@@ -249,42 +255,42 @@ class WanTransformerInfer(BaseTransformerInfer):
                 del k_img, v_img, img_attn_out
                 torch.cuda.empty_cache()

-        attn_out = weights.cross_attn_o.apply(attn_out)
+        attn_out = phase.cross_attn_o.apply(attn_out)

         if self.clean_cuda_cache:
             del q, k, v, norm3_out, context, context_img
             torch.cuda.empty_cache()

         return x, attn_out

-    def infer_ffn(self, weights, x, attn_out, c_shift_msa, c_scale_msa):
+    def infer_ffn(self, phase, x, attn_out, c_shift_msa, c_scale_msa):
         x.add_(attn_out)
         if self.clean_cuda_cache:
             del attn_out
             torch.cuda.empty_cache()

-        if hasattr(weights, "smooth_norm2_weight"):
-            norm2_weight = (1 + c_scale_msa.squeeze()) * weights.smooth_norm2_weight.tensor
-            norm2_bias = c_shift_msa.squeeze() * weights.smooth_norm2_bias.tensor
+        if hasattr(phase, "smooth_norm2_weight"):
+            norm2_weight = (1 + c_scale_msa.squeeze()) * phase.smooth_norm2_weight.tensor
+            norm2_bias = c_shift_msa.squeeze() * phase.smooth_norm2_bias.tensor
         else:
             norm2_weight = 1 + c_scale_msa.squeeze()
             norm2_bias = c_shift_msa.squeeze()

-        norm2_out = weights.norm2.apply(x)
+        norm2_out = phase.norm2.apply(x)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.sensitive_layer_dtype)
         norm2_out.mul_(norm2_weight).add_(norm2_bias)
         if self.sensitive_layer_dtype != self.infer_dtype:
             norm2_out = norm2_out.to(self.infer_dtype)

-        y = weights.ffn_0.apply(norm2_out)
+        y = phase.ffn_0.apply(norm2_out)
         if self.clean_cuda_cache:
             del norm2_out, x, norm2_weight, norm2_bias
             torch.cuda.empty_cache()

         y = torch.nn.functional.gelu(y, approximate="tanh")
         if self.clean_cuda_cache:
             torch.cuda.empty_cache()
-        y = weights.ffn_2.apply(y)
+        y = phase.ffn_2.apply(y)
         return y
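The renamed pre_process implements the AdaLN-style modulation split: the block's modulation table is added to the timestep embedding and chunked into six tensors (shift/scale/gate for the attention path and for the FFN path). A shape sketch of the else branch (the dimension and batch layout used here are assumptions for illustration, not values from the repo):

import torch

dim = 1536
modulation = torch.zeros(1, 6, dim)   # stand-in for block.compute_phases[0].modulation.tensor
embed0 = torch.randn(1, 6, dim)       # stand-in for pre_infer_out.embed0

shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (modulation + embed0).chunk(6, dim=1)
print(shift_msa.shape)                # torch.Size([1, 1, 1536])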
lightx2v/models/networks/wan/infer/vace/transformer_infer.py

@@ -5,41 +5,33 @@ from lightx2v.utils.envs import *
 class WanVaceTransformerInfer(WanOffloadTransformerInfer):
     def __init__(self, config):
         super().__init__(config)
-        self.vace_block_nums = len(self.config.vace_layers)
+        self.vace_blocks_num = len(self.config.vace_layers)
         self.vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(self.config.vace_layers)}

     def infer(self, weights, pre_infer_out):
-        pre_infer_out.adapter_output["hints"] = self.infer_vace(weights, pre_infer_out)
-        x = self.infer_main_blocks(weights, pre_infer_out)
+        pre_infer_out.c = self.vace_pre_process(weights.vace_patch_embedding, pre_infer_out.vace_context)
+        self.infer_vace_blocks(weights.vace_blocks, pre_infer_out)
+        x = self.infer_main_blocks(weights.blocks, pre_infer_out)
         return self.infer_non_blocks(weights, x, pre_infer_out.embed)

-    def infer_vace(self, weights, pre_infer_out):
-        c = weights.vace_patch_embedding.apply(pre_infer_out.vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
+    def vace_pre_process(self, patch_embedding, vace_context):
+        c = patch_embedding.apply(vace_context.unsqueeze(0).to(self.sensitive_layer_dtype))
         c = c.flatten(2).transpose(1, 2).contiguous().squeeze(0)
+        return c

+    def infer_vace_blocks(self, vace_blocks, pre_infer_out):
+        pre_infer_out.adapter_output["hints"] = []
         self.infer_state = "vace"
-        hints = []
-        for i in range(self.vace_block_nums):
-            c, c_skip = self.infer_vace_block(weights.vace_blocks[i], c, pre_infer_out.x, pre_infer_out)
-            hints.append(c_skip)
+        if hasattr(self, "weights_stream_mgr"):
+            self.weights_stream_mgr.init(self.vace_blocks_num, self.phases_num, self.offload_ratio)
+        self.infer_func(vace_blocks, pre_infer_out.c, pre_infer_out)
         self.infer_state = "base"
-        return hints
-
-    def infer_vace_block(self, weights, c, x, pre_infer_out):
-        if hasattr(weights, "before_proj"):
-            c = weights.before_proj.apply(c) + x
-        c = self.infer_block(weights, c, pre_infer_out)
-        c_skip = weights.after_proj.apply(c)
-        return c, c_skip
+        if hasattr(self, "weights_stream_mgr"):
+            self.weights_stream_mgr.init(self.blocks_num, self.phases_num, self.offload_ratio)

     def post_process(self, x, y, c_gate_msa, pre_infer_out):
         x = super().post_process(x, y, c_gate_msa, pre_infer_out)
         if self.infer_state == "base" and self.block_idx in self.vace_blocks_mapping:
             hint_idx = self.vace_blocks_mapping[self.block_idx]
             x = x + pre_infer_out.adapter_output["hints"][hint_idx] * pre_infer_out.adapter_output.get("context_scale", 1.0)
         return x
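With this change the VACE hints are produced by running the shared infer_func over weights.vace_blocks (so they inherit the same offload behavior as the base blocks), and they are folded back into the base pass inside post_process via vace_blocks_mapping. The sketch below replays just that bookkeeping with toy tensors (the layer indices, sizes and context_scale value are made up; only the mapping comprehension and the additive hint update mirror the diff):

import torch

vace_layers = [0, 2, 4]                     # hypothetical subset of base block indices
vace_blocks_mapping = {orig_idx: seq_idx for seq_idx, orig_idx in enumerate(vace_layers)}
hints = [torch.randn(16, 1536) for _ in vace_layers]   # one hint per VACE block
context_scale = 1.0

x = torch.randn(16, 1536)
for block_idx in range(6):                  # stand-in for the base-block loop
    # ... base block inference would update x here ...
    if block_idx in vace_blocks_mapping:
        hint_idx = vace_blocks_mapping[block_idx]
        x = x + hints[hint_idx] * context_scale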
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -27,7 +27,7 @@ class WanTransformerWeights(WeightModule):
         self.add_module("blocks", self.blocks)

-        # post blocks weights
+        # non blocks weights
         self.register_parameter("norm", LN_WEIGHT_REGISTER["Default"]())
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

@@ -67,15 +67,6 @@ class WanTransformerAttentionBlock(WeightModule):
         self.compute_phases = WeightModuleList(
             [
-                WanModulation(
-                    block_index,
-                    block_prefix,
-                    task,
-                    mm_type,
-                    config,
-                    self.lazy_load,
-                    self.lazy_load_file,
-                ),
                 WanSelfAttention(
                     block_index,
                     block_prefix,

@@ -109,7 +100,7 @@ class WanTransformerAttentionBlock(WeightModule):
         self.add_module("compute_phases", self.compute_phases)


-class WanModulation(WeightModule):
+class WanSelfAttention(WeightModule):
     def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
         super().__init__()
         self.block_index = block_index

@@ -131,20 +122,6 @@ class WanModulation(WeightModule):
             ),
         )

-
-class WanSelfAttention(WeightModule):
-    def __init__(self, block_index, block_prefix, task, mm_type, config, lazy_load, lazy_load_file):
-        super().__init__()
-        self.block_index = block_index
-        self.mm_type = mm_type
-        self.task = task
-        self.config = config
-        self.quant_method = config.get("quant_method", None)
-        self.sparge = config.get("sparge", False)
-        self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
         self.add_module(
             "norm1",
             LN_WEIGHT_REGISTER["Default"](),
lightx2v/models/networks/wan/weights/vace/transformer_weights.py

@@ -9,8 +9,6 @@ from lightx2v.utils.registry_factory import (
 )

-# "vace_layers": [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28],
-# {0: 0, 2: 1, 4: 2, 6: 3, 8: 4, 10: 5, 12: 6, 14: 7, 16: 8, 18: 9, 20: 10, 22: 11, 24: 12, 26: 13, 28: 14}

 class WanVaceTransformerWeights(WanTransformerWeights):
     def __init__(self, config):
         super().__init__(config)

@@ -44,7 +42,7 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
     def __init__(self, base_block_idx, block_index, task, mm_type, config, block_prefix):
         super().__init__(block_index, task, mm_type, config, block_prefix)
         if base_block_idx == 0:
-            self.add_module(
+            self.compute_phases[0].add_module(
                 "before_proj",
                 MM_WEIGHT_REGISTER[self.mm_type](
                     f"{block_prefix}.{self.block_index}.before_proj.weight",

@@ -53,7 +51,8 @@ class WanVaceTransformerAttentionBlock(WanTransformerAttentionBlock):
                     self.lazy_load_file,
                 ),
             )
-        self.add_module(
+        self.compute_phases[-1].add_module(
             "after_proj",
             MM_WEIGHT_REGISTER[self.mm_type](
                 f"{block_prefix}.{self.block_index}.after_proj.weight",