xuwx1 / LightX2V · Commit f3b4ba24 (unverified)

[Recon] Reconstruct disk-cpu-cuda offload system (#578)

Authored Dec 08, 2025 by Gu Shiqiao; committed via GitHub on Dec 08, 2025. Parent: 67d6c6c1

Showing 15 changed files with 512 additions and 213 deletions (+512 / -213):
lightx2v/common/offload/manager.py                                +70   -3
lightx2v/common/ops/mm/mm_weight.py                               +73   -45
lightx2v/common/ops/norm/layer_norm_weight.py                     +15   -7
lightx2v/common/ops/norm/rms_norm_weight.py                       +11   -6
lightx2v/common/ops/tensor/tensor.py                              +11   -6
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py       +12   -2
lightx2v/models/networks/wan/infer/offload/transformer_infer.py   +10   -0
lightx2v/models/networks/wan/infer/transformer_infer.py            +2   -2
lightx2v/models/networks/wan/infer/utils.py                       +14   -17
lightx2v/models/networks/wan/model.py                             +20   -8
lightx2v/models/networks/wan/weights/transformer_weights.py       +27   -25
lightx2v/models/runners/default_runner.py                         +13   -1
lightx2v/models/runners/wan/wan_distill_runner.py                 +90   -54
lightx2v/models/runners/wan/wan_runner.py                         +68   -35
lightx2v/utils/profiler.py                                        +76   -2
lightx2v/common/offload/manager.py

```diff
+from concurrent.futures import ThreadPoolExecutor
+
 import torch
 from loguru import logger
 from packaging.version import parse
+from tqdm import tqdm
+
+from lightx2v.utils.profiler import ExcludedProfilingContext
 from lightx2v_platform.base.global_var import AI_DEVICE

 torch_device_module = getattr(torch, AI_DEVICE)
@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
         self.offload_granularity = offload_granularity
         self.init_stream = torch_device_module.Stream(priority=0)
         self.need_init_first_buffer = True
+        self.lazy_load = False
         torch_version = parse(torch.__version__.split("+")[0])
         if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
             self.cuda_load_stream = torch_device_module.Stream(priority=1)
@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
     def init_first_buffer(self, blocks, adapter_block_idx=None):
         with torch_device_module.stream(self.init_stream):
             if hasattr(self, "cpu_buffers"):
-                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
+                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
             else:
                 if self.offload_granularity == "block":
                     self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
     def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
         with torch_device_module.stream(self.cuda_load_stream):
             if hasattr(self, "cpu_buffers"):
-                self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
-                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
+                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
             else:
                 self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
     def swap_phases(self):
         self.cuda_load_stream.synchronize()
         self.compute_stream.synchronize()
+
+    @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
+    def warm_up_cpu_buffers(self, blocks_num):
+        logger.info("🔥 Warming up cpu buffers...")
+        for i in tqdm(range(blocks_num)):
+            for phase in self.cpu_buffers[0]:
+                phase.load_state_dict_from_disk(i, None)
+            for phase in self.cpu_buffers[1]:
+                phase.load_state_dict_from_disk(i, None)
+        for phase in self.cpu_buffers[0]:
+            phase.load_state_dict_from_disk(0, None)
+        for phase in self.cpu_buffers[1]:
+            phase.load_state_dict_from_disk(1, None)
+        logger.info("✅ CPU buffers warm-up completed.")
+
+    def init_lazy_load(self, num_workers=6):
+        self.lazy_load = True
+        self.executor = ThreadPoolExecutor(max_workers=num_workers)
+        self.prefetch_futures = []
+        self.prefetch_block_idx = -1
+
+    def start_prefetch_block(self, block_idx, adapter_block_idx=None):
+        self.prefetch_block_idx = block_idx
+        self.prefetch_futures = []
+        for phase in self.cpu_buffers[1]:
+            future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
+            self.prefetch_futures.append(future)
+
+    def swap_cpu_buffers(self):
+        import time
+
+        wait_start = time.time()
+        already_done = all(f.done() for f in self.prefetch_futures)
+        for f in self.prefetch_futures:
+            f.result()
+        wait_time = time.time() - wait_start
+        logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
+        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
+
+    def shutdown(self, wait=True):
+        """Shutdown the thread pool executor and wait for all pending tasks to complete."""
+        if hasattr(self, "executor") and self.executor is not None:
+            # Wait for all pending futures to complete before shutting down
+            if hasattr(self, "prefetch_futures"):
+                for f in self.prefetch_futures:
+                    try:
+                        if not f.done():
+                            f.result()
+                    except Exception:
+                        pass
+            self.executor.shutdown(wait=wait)
+            self.executor = None
+            logger.debug("ThreadPoolExecutor shut down successfully.")
+
+    def __del__(self):
+        """Cleanup method to ensure executor is shut down when object is destroyed."""
+        try:
+            if hasattr(self, "executor") and self.executor is not None:
+                self.executor.shutdown(wait=False)
+        except Exception:
+            pass
```
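The lazy-load methods added above implement a ping-pong scheme: while compute consumes the phase weights in `cpu_buffers[0]`, a thread pool fills `cpu_buffers[1]` from disk, and `swap_cpu_buffers` flips the pair once the prefetch futures resolve. A minimal standalone sketch of the same pattern, with disk I/O and the buffer contents reduced to illustrative stubs (none of the names below are part of the commit):

```python
import time
from concurrent.futures import ThreadPoolExecutor


def load_block_from_disk(buffer, block_idx):
    """Stand-in for load_state_dict_from_disk: fill `buffer` for `block_idx`."""
    time.sleep(0.01)  # emulate disk latency
    buffer["block"] = block_idx


class PingPongPrefetcher:
    def __init__(self, num_workers=2):
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.buffers = [{}, {}]  # [active, prefetch]
        self.futures = []

    def start_prefetch(self, block_idx):
        # Fill the inactive buffer in the background.
        self.futures = [self.executor.submit(load_block_from_disk, self.buffers[1], block_idx)]

    def swap(self):
        # Block until the prefetch finished, then flip active/prefetch.
        for f in self.futures:
            f.result()
        self.buffers = [self.buffers[1], self.buffers[0]]


prefetcher = PingPongPrefetcher()
load_block_from_disk(prefetcher.buffers[0], 0)  # warm the first buffer synchronously
for block_idx in range(4):
    prefetcher.start_prefetch((block_idx + 1) % 4)             # disk I/O for the next block...
    print("compute on block", prefetcher.buffers[0]["block"])  # ...overlaps this block's compute
    prefetcher.swap()
prefetcher.executor.shutdown(wait=True)
```

The explicit `shutdown`/`__del__` pair in the diff exists for the same reason the sketch ends with `shutdown(wait=True)`: a live `ThreadPoolExecutor` would otherwise keep worker threads alive after the runner drops the manager.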
lightx2v/common/ops/mm/mm_weight.py

```diff
+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
 from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.ggml_tensor import GGMLTensor
@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
     def _get_source_tensor(self, source_name, weight_dict=None):
         if self.lazy_load:
-            return self.lazy_load_file.get_tensor(source_name)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                return lazy_load_file.get_tensor(source_name)
         return weight_dict[source_name]

     def _create_pin_tensor(self, tensor, transpose=False):
@@ -145,15 +149,18 @@ class MMWeight(MMWeightTemplate):
             self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)

     def _load_cpu_pin_buffers(self):
-        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
-        self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
-        if self.bias_name is not None:
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias = self._create_pin_tensor(bias_tensor)
-        else:
-            self.bias = None
-            self.pin_bias = None
+        if self.lazy_load:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+                self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
+                if self.bias_name is not None:
+                    bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                    self.pin_bias = self._create_pin_tensor(bias_tensor)
+                else:
+                    self.bias = None
+                    self.pin_bias = None

     def _load_default_tensors(self, weight_dict):
         if not self.lazy_load:
@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
             else:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
-        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
-        self.pin_weight = self.pin_weight.copy_(weight_tensor)
-        del weight_tensor
         if self.bias_name is not None:
             if self.is_post_adapter:
                 assert adapter_block_index is not None
@@ -208,9 +211,16 @@ class MMWeight(MMWeightTemplate):
             else:
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias.copy_(bias_tensor)
-            del bias_tensor
+        lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+            weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+            self.pin_weight = self.pin_weight.copy_(weight_tensor)
+            del weight_tensor
+            if self.bias_name is not None:
+                bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                self.pin_bias.copy_(bias_tensor)
+                del bias_tensor

     def load_state_dict(self, destination, block_index, adapter_block_index=None):
         if self.is_post_adapter:
@@ -283,9 +293,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self._load_default_tensors(weight_dict)

     def _load_cuda_buffers(self, weight_dict):
-        source = self.lazy_load_file if self.lazy_load else weight_dict
-        self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
-        self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+        if self.lazy_load:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+                self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+                self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+        else:
+            source = weight_dict
+            self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+            self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)

     def _get_cuda_tensor_pair(self, source, is_lazy):
         if is_lazy:
@@ -318,30 +334,38 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def _get_cpu_pin_tensor_pair(self, source, is_lazy):
         if is_lazy:
-            weight_tensor = source.get_tensor(self.weight_name)
-            scale_tensor = source.get_tensor(self.weight_scale_name)
-            scale_dtype = torch.float
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+                weight_tensor = source.get_tensor(self.weight_name)
+                scale_tensor = source.get_tensor(self.weight_scale_name)
+                scale_dtype = torch.float
+                pin_weight = self._create_pin_tensor(weight_tensor)
+                pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
         else:
             weight_tensor = source[self.weight_name]
             scale_tensor = source[self.weight_scale_name]
             scale_dtype = torch.float
-        pin_weight = self._create_pin_tensor(weight_tensor)
-        pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
+            pin_weight = self._create_pin_tensor(weight_tensor)
+            pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
         return pin_weight, pin_scale

     def _get_cpu_pin_bias_tensor(self, source, is_lazy):
         if self.bias_name is None:
             return None
         if is_lazy:
-            bias_tensor = source.get_tensor(self.bias_name)
-            if not self.bias_force_fp32:
-                bias_tensor = bias_tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+                bias_tensor = source.get_tensor(self.bias_name)
+                if not self.bias_force_fp32:
+                    bias_tensor = bias_tensor.to(self.infer_dtype)
+                if self.bias_force_fp32:
+                    bias_tensor = bias_tensor.to(torch.float32)
+                return self._create_pin_tensor(bias_tensor)
         else:
             bias_tensor = source[self.bias_name]
-        if self.bias_force_fp32:
-            bias_tensor = bias_tensor.to(torch.float32)
-        return self._create_pin_tensor(bias_tensor)
+            if self.bias_force_fp32:
+                bias_tensor = bias_tensor.to(torch.float32)
+            return self._create_pin_tensor(bias_tensor)

     def _create_pin_tensor(self, tensor, dtype=None):
         dtype = dtype or tensor.dtype
@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
         self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
-        if self.weight_need_transpose:
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
-        else:
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
-        self.pin_weight = self.pin_weight.copy_(weight_tensor)
-        weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name)
-        self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
-        del weight_tensor
         if self.bias_name is not None:
             if self.is_post_adapter:
                 assert adapter_block_index is not None
@@ -661,9 +674,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
             else:
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias.copy_(bias_tensor)
-            del bias_tensor
+        lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+            if self.weight_need_transpose:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+            else:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+            self.pin_weight = self.pin_weight.copy_(weight_tensor)
+            del weight_tensor
+            weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
+            self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
+            del weight_scale_tensor
+            if self.bias_name is not None:
+                bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                self.pin_bias.copy_(bias_tensor)
+                del bias_tensor

 @MM_WEIGHT_REGISTER("fp8-vllm")
```
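A pattern repeated throughout this file (and mirrored in the norm and tensor ops below) replaces the single long-lived `safe_open` handle with short-lived handles over per-block shards named `block_{i}.safetensors`, where `i` is either parsed out of the weight name (`blocks.{i}. ...`) or passed in as `block_index`. A minimal sketch of that lookup under the same shard-naming assumption (the directory and example weight name are illustrative):

```python
import os

from safetensors import safe_open


def read_block_tensor(ckpt_dir, weight_name):
    # "blocks.12.attn.qkv.weight" -> shard "block_12.safetensors"
    block_idx = weight_name.split(".")[1]
    shard_path = os.path.join(ckpt_dir, f"block_{block_idx}.safetensors")
    # Open the shard only for the duration of the read, so no handle
    # (or its mmap) outlives the copy into the pinned staging tensor.
    with safe_open(shard_path, framework="pt", device="cpu") as shard:
        return shard.get_tensor(weight_name)


# tensor = read_block_tensor("/path/to/ckpt", "blocks.12.attn.qkv.weight")
```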
lightx2v/common/ops/norm/layer_norm_weight.py

```diff
+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
 from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
@@ -53,9 +55,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
         if name is None:
             return None
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[name]
         return tensor
@@ -151,24 +155,28 @@ class LNWeightTemplate(metaclass=ABCMeta):
     def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
         if self.weight_name is not None:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
             if self.is_post_adapter:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
             else:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
-            self.pin_weight = self.pin_weight.copy_(weight_tensor)
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+                self.pin_weight = self.pin_weight.copy_(weight_tensor)
             del weight_tensor
         if self.bias_name is not None:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
             if self.is_post_adapter:
                 assert adapter_block_index is not None
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
             else:
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
-            self.pin_bias.copy_(bias_tensor)
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
+                self.pin_bias.copy_(bias_tensor)
             del bias_tensor
```
lightx2v/common/ops/norm/rms_norm_weight.py

```diff
+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
 from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
@@ -46,9 +48,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
     def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(self.weight_name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(self.weight_name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[self.weight_name]
         return tensor
@@ -107,9 +111,10 @@ class RMSWeightTemplate(metaclass=ABCMeta):
             self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
         else:
             self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
-        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
-        self.pin_weight = self.pin_weight.copy_(weight_tensor)
+        lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+            weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+            self.pin_weight = self.pin_weight.copy_(weight_tensor)
         del weight_tensor
```
lightx2v/common/ops/tensor/tensor.py

```diff
+import os
 import re

 import torch
 from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import TENSOR_REGISTER
@@ -39,9 +41,11 @@ class DefaultTensor:
     def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(self.tensor_name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(self.tensor_name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[self.tensor_name]
         return tensor
@@ -92,7 +96,8 @@ class DefaultTensor:
             self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
         else:
             self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
-        tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
-        self.pin_tensor = self.pin_tensor.copy_(tensor)
+        lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+            tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
+            self.pin_tensor = self.pin_tensor.copy_(tensor)
         del tensor
```
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py

```diff
 import torch
 from einops import rearrange
-from flash_attn import flash_attn_varlen_qkvpacked_func
-from flash_attn.bert_padding import pad_input, unpad_input
+from loguru import logger
+
+try:
+    from flash_attn import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    flash_attn_varlen_qkvpacked_func = None
+    logger.info("flash_attn_varlen_qkvpacked_func not available")
+
+try:
+    from flash_attn.bert_padding import pad_input, unpad_input
+except ImportError:
+    pad_input = None
+    unpad_input = None
+    logger.info("flash_attn.bert_padding not available")

 try:
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
 except ImportError:
```
lightx2v/models/networks/wan/infer/offload/transformer_infer.py

```diff
@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         if offload_granularity != "model":
             self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
+            self.lazy_load = self.config.get("lazy_load", False)
+            if self.lazy_load and offload_granularity == "phase":
+                self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))

     def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
     def infer_with_phases_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
+            if self.lazy_load:
+                next_prefetch = (block_idx + 1) % len(blocks)
+                self.offload_manager.start_prefetch_block(next_prefetch)
             x = self.infer_phases(block_idx, blocks, x, pre_infer_out)
             if self.clean_cuda_cache:
                 del (
@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                 self.offload_manager.init_first_buffer(blocks)
             next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
             next_phase_idx = (phase_idx + 1) % self.phases_num
+            if self.lazy_load:
+                if phase_idx == self.phases_num - 1:
+                    self.offload_manager.swap_cpu_buffers()
             self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)
             with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
```
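Together with the manager changes, the phase-offload loop now overlaps three levels of the hierarchy: disk-to-CPU prefetch of the next block on the thread pool, CPU-to-GPU copy of the next phase on the load stream, and compute of the current phase on the compute stream. A self-contained schematic of the loop ordering, with the stream and copy work reduced to prints (the stub class is illustrative, not the real manager):

```python
class StubOffloadManager:
    """Stand-in for WeightAsyncStreamManager; prints instead of copying."""

    def start_prefetch_block(self, block_idx):
        print(f"disk -> cpu : prefetch block {block_idx}")

    def swap_cpu_buffers(self):
        print("cpu buffers : swap active/prefetch sets")

    def prefetch_phase(self, block_idx, phase_idx):
        print(f"cpu -> gpu  : block {block_idx} phase {phase_idx}")


def infer_with_phases_offload(manager, num_blocks, phases_num, lazy_load=True):
    """Schematic of the loop ordering added in this commit."""
    for block_idx in range(num_blocks):
        if lazy_load:
            # Disk I/O for the next block overlaps this block's compute.
            manager.start_prefetch_block((block_idx + 1) % num_blocks)
        for phase_idx in range(phases_num):
            last_phase = phase_idx == phases_num - 1
            next_block = (block_idx + 1) % num_blocks if last_phase else block_idx
            next_phase = (phase_idx + 1) % phases_num
            if lazy_load and last_phase:
                # The prefetched CPU set becomes active before the next
                # block's host-to-device copies start.
                manager.swap_cpu_buffers()
            manager.prefetch_phase(next_block, next_phase)
            print(f"compute     : block {block_idx} phase {phase_idx}")


infer_with_phases_offload(StubOffloadManager(), num_blocks=2, phases_num=3)
```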
lightx2v/models/networks/wan/infer/transformer_infer.py

```diff
@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del norm1_out, norm1_weight, norm1_bias
+            del norm1_out, shift_msa, scale_msa
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:
@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         y = phase.ffn_0.apply(norm2_out)

         if self.clean_cuda_cache:
-            del norm2_out, x, norm2_weight, norm2_bias
+            del norm2_out, x
             torch.cuda.empty_cache()

         y = torch.nn.functional.gelu(y, approximate="tanh")

         if self.clean_cuda_cache:
```
lightx2v/models/networks/wan/infer/utils.py

```diff
@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
     rope_func,
 ):
     seq_len = cos_sin_cache.size(0)
-    xq_output_chunks = []
-    xk_output_chunks = []
+    x_q = torch.empty_like(xq)
+    x_k = torch.empty_like(xk)
     for start in range(0, seq_len, chunk_size):
         end = min(start + chunk_size, seq_len)
         xq_chunk = xq[start:end]
         xk_chunk = xk[start:end]
         cos_sin_chunk = cos_sin_cache[start:end]
-        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
-        xq_output_chunks.append(xq_chunk)
-        xk_output_chunks.append(xk_chunk)
-        torch.cuda.empty_cache()
-    x_q = torch.cat(xq_output_chunks, dim=0)
-    del xq_output_chunks
-    torch.cuda.empty_cache()
-    x_k = torch.cat(xk_output_chunks, dim=0)
-    del xk_output_chunks
-    torch.cuda.empty_cache()
-    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
+        xq_chunk_out, xk_chunk_out = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        x_q[start:end].copy_(xq_chunk_out, non_blocking=True)
+        x_k[start:end].copy_(xk_chunk_out, non_blocking=True)
+        del xq_chunk_out, xk_chunk_out
+    target_dtype = GET_DTYPE()
+    if x_q.dtype != target_dtype:
+        x_q = x_q.to(target_dtype)
+    if x_k.dtype != target_dtype:
+        x_k = x_k.to(target_dtype)
+    return x_q, x_k


 def apply_wan_rope_with_flashinfer(
```
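The rewrite drops the append-then-`torch.cat` pattern in favor of preallocated outputs written in place, so the peak no longer holds both the chunk list and the concatenated result, and the per-chunk `empty_cache` calls disappear. The same idea in isolation (shapes and the pass-through `rope_func` are illustrative):

```python
import torch


def chunked_apply(xq, xk, chunk_size, rope_func):
    # Preallocate once; each chunk result is copied into its slice, so peak
    # memory stays O(output) instead of O(output + all chunks) as with cat.
    x_q, x_k = torch.empty_like(xq), torch.empty_like(xk)
    for start in range(0, xq.size(0), chunk_size):
        end = min(start + chunk_size, xq.size(0))
        q_out, k_out = rope_func(xq[start:end], xk[start:end])
        x_q[start:end].copy_(q_out)
        x_k[start:end].copy_(k_out)
    return x_q, x_k


q, k = torch.randn(10, 4), torch.randn(10, 4)
out_q, out_k = chunked_apply(q, k, chunk_size=4, rope_func=lambda a, b: (a, b))
```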
lightx2v/models/networks/wan/model.py

```diff
@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
             safetensors_files = [safetensors_path]

         if self.lazy_load:
-            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
-            self.lazy_load_path = safetensors_files[0]
+            self.lazy_load_path = safetensors_path
+            non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+            if os.path.exists(non_block_file):
+                safetensors_files = [non_block_file]
+            else:
+                raise ValueError(f"Non-block file not found in {safetensors_path}")

         weight_dict = {}
         for file_path in safetensors_files:
@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
-        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         if self.config.get("dit_quantized_ckpt", None):
             safetensors_path = self.config["dit_quantized_ckpt"]
         else:
@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
             safetensors_path = os.path.dirname(safetensors_path)

         if self.lazy_load:
-            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
-            self.lazy_load_path = safetensors_files[0]
+            self.lazy_load_path = safetensors_path
+            non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+            if os.path.exists(non_block_file):
+                safetensors_files = [non_block_file]
+            else:
+                raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path")

         weight_dict = {}
         for safetensor_path in safetensors_files:
@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
-            if self.lazy_load:
-                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
+            self._init_offload_manager()
+
+    def _init_offload_manager(self):
+        self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+        if self.lazy_load:
+            self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
+            if self.config.get("warm_up_cpu_buffers", False):
+                self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
```
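Both checkpoint loaders now treat `lazy_load_path` as a directory with a fixed sharded layout: one `non_block.safetensors` holding everything outside the transformer blocks, plus the per-block `block_{i}.safetensors` shards that the ops above read on demand. A small validation sketch under that assumed layout (this helper is illustrative, not part of the commit):

```python
import os


def check_lazy_load_layout(ckpt_dir, blocks_num):
    """Verify the sharded layout the lazy-load path expects."""
    non_block = os.path.join(ckpt_dir, "non_block.safetensors")
    if not os.path.exists(non_block):
        raise ValueError(f"Non-block file not found in {ckpt_dir}")
    missing = [i for i in range(blocks_num) if not os.path.exists(os.path.join(ckpt_dir, f"block_{i}.safetensors"))]
    if missing:
        raise ValueError(f"Missing block shards: {missing}")


# check_lazy_load_layout("/path/to/wan_ckpt", blocks_num=40)
```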
lightx2v/models/networks/wan/weights/transformer_weights.py

```diff
-from safetensors import safe_open
-
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.utils.registry_factory import (
     ATTN_WEIGHT_REGISTER,
@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
         if config.get("do_mm_calib", False):
             self.mm_type = "Calib"
         self.lazy_load = self.config.get("lazy_load", False)
-        if not self.lazy_load:
-            self.lazy_load_file = None
-        else:
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         self.blocks = WeightModuleList(
             [
                 WanTransformerAttentionBlock(
@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
                     create_cpu_buffer=False,
                     block_prefix="blocks",
                     lazy_load=self.lazy_load,
-                    lazy_load_file=self.lazy_load_file,
+                    lazy_load_path=lazy_load_path,
                 )
                 for i in range(self.blocks_num)
             ]
         )
-        self.register_offload_buffers(config)
+        self.register_offload_buffers(config, lazy_load_path)
         self.add_module("blocks", self.blocks)

         # non blocks weights
@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def register_offload_buffers(self, config):
+    def register_offload_buffers(self, config, lazy_load_path):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2
@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
                         create_cpu_buffer=False,
                         block_prefix="blocks",
                         lazy_load=self.lazy_load,
-                        lazy_load_file=self.lazy_load_file,
+                        lazy_load_path=lazy_load_path,
                     )
                     for i in range(self.offload_blocks_num)
                 ]
@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
                         create_cpu_buffer=True,
                         block_prefix="blocks",
                         lazy_load=self.lazy_load,
-                        lazy_load_file=self.lazy_load_file,
+                        lazy_load_path=lazy_load_path,
                     )
                     for i in range(self.offload_blocks_num)
                 ]
@@ -104,22 +98,27 @@ class WanTransformerWeights(WeightModule):
                 create_cpu_buffer=False,
                 block_prefix="blocks",
                 lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
+                lazy_load_path=lazy_load_path,
             ).compute_phases
             self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
             self.offload_block_cuda_buffers = None
             if self.lazy_load:
-                self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
-                    block_index=0,
-                    task=self.task,
-                    mm_type=self.mm_type,
-                    config=self.config,
-                    create_cuda_buffer=False,
-                    create_cpu_buffer=True,
-                    block_prefix="blocks",
-                    lazy_load=self.lazy_load,
-                    lazy_load_file=self.lazy_load_file,
-                ).compute_phases
+                self.offload_phase_cpu_buffers = WeightModuleList(
+                    [
+                        WanTransformerAttentionBlock(
+                            block_index=i,
+                            task=self.task,
+                            mm_type=self.mm_type,
+                            config=self.config,
+                            create_cuda_buffer=False,
+                            create_cpu_buffer=True,
+                            block_prefix="blocks",
+                            lazy_load=self.lazy_load,
+                            lazy_load_path=lazy_load_path,
+                        ).compute_phases
+                        for i in range(2)
+                    ]
+                )
                 self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
                 self.offload_block_cpu_buffers = None
@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
         create_cpu_buffer=False,
         block_prefix="blocks",
         lazy_load=False,
-        lazy_load_file=None,
+        lazy_load_path=None,
     ):
         super().__init__()
         self.block_index = block_index
@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
         self.quant_method = config.get("quant_method", None)
         self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
+        if self.lazy_load:
+            self.lazy_load_file = lazy_load_path
+        else:
+            self.lazy_load_file = None
         self.compute_phases = WeightModuleList(
             [
```
lightx2v/models/runners/default_runner.py

```diff
@@ -185,7 +185,19 @@ class DefaultRunner(BaseRunner):
         del self.inputs
         self.input_info = None
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
-            del self.model
+            if hasattr(self.model, "model") and len(self.model.model) == 2:  # MultiModelStruct
+                for model in self.model.model:
+                    if hasattr(model.transformer_infer, "offload_manager"):
+                        del model.transformer_infer.offload_manager
+                    torch.cuda.empty_cache()
+                    gc.collect()
+                    del model
+            else:
+                if hasattr(self.model.transformer_infer, "offload_manager"):
+                    del self.model.transformer_infer.offload_manager
+                torch.cuda.empty_cache()
+                gc.collect()
+                del self.model

         if self.config.get("do_mm_calib", False):
             calib_path = os.path.join(os.getcwd(), "calib.pt")
             torch.save(CALIB, calib_path)
```
lightx2v/models/runners/wan/wan_distill_runner.py

```diff
@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
             self.to_cuda(model_index=1)
             self.cur_model_index = 1

+    def infer(self, inputs):
+        self.get_current_model_index()
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.model[self.cur_model_index].infer(inputs)
+        else:
+            if self.model[self.cur_model_index] is not None:
+                self.model[self.cur_model_index].infer(inputs)
+            else:
+                if self.cur_model_index == 0:
+                    high_noise_model = WanDistillModel(
+                        self.high_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_high_noise",
+                    )
+                    high_noise_model.set_scheduler(self.scheduler)
+                    self.model[0] = high_noise_model
+                    self.model[0].infer(inputs)
+                elif self.cur_model_index == 1:
+                    low_noise_model = WanDistillModel(
+                        self.low_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_low_noise",
+                    )
+                    low_noise_model.set_scheduler(self.scheduler)
+                    self.model[1] = low_noise_model
+                    self.model[1].infer(inputs)

 @RUNNER_REGISTER("wan2.2_moe_distill")
 class Wan22MoeDistillRunner(WanDistillRunner):
@@ -101,61 +130,68 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             raise FileNotFoundError(f"Low Noise Model does not find")

     def load_transformer(self):
-        use_high_lora, use_low_lora = False, False
-        if self.config.get("lora_configs") and self.config["lora_configs"]:
-            for lora_config in self.config["lora_configs"]:
-                if lora_config.get("name", "") == "high_noise_model":
-                    use_high_lora = True
-                elif lora_config.get("name", "") == "low_noise_model":
-                    use_low_lora = True
-        if use_high_lora:
-            high_noise_model = WanModel(
-                self.high_noise_model_path,
-                self.config,
-                self.init_device,
-                model_type="wan2.2_moe_high_noise",
-            )
-            high_lora_wrapper = WanLoraWrapper(high_noise_model)
-            for lora_config in self.config["lora_configs"]:
-                if lora_config.get("name", "") == "high_noise_model":
-                    lora_path = lora_config["path"]
-                    strength = lora_config.get("strength", 1.0)
-                    lora_name = high_lora_wrapper.load_lora(lora_path)
-                    high_lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
-        else:
-            high_noise_model = WanDistillModel(
-                self.high_noise_model_path,
-                self.config,
-                self.init_device,
-                model_type="wan2.2_moe_high_noise",
-            )
-        if use_low_lora:
-            low_noise_model = WanModel(
-                self.low_noise_model_path,
-                self.config,
-                self.init_device,
-                model_type="wan2.2_moe_low_noise",
-            )
-            low_lora_wrapper = WanLoraWrapper(low_noise_model)
-            for lora_config in self.config["lora_configs"]:
-                if lora_config.get("name", "") == "low_noise_model":
-                    lora_path = lora_config["path"]
-                    strength = lora_config.get("strength", 1.0)
-                    lora_name = low_lora_wrapper.load_lora(lora_path)
-                    low_lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
-        else:
-            low_noise_model = WanDistillModel(
-                self.low_noise_model_path,
-                self.config,
-                self.init_device,
-                model_type="wan2.2_moe_low_noise",
-            )
-        return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            use_high_lora, use_low_lora = False, False
+            if self.config.get("lora_configs") and self.config["lora_configs"]:
+                for lora_config in self.config["lora_configs"]:
+                    if lora_config.get("name", "") == "high_noise_model":
+                        use_high_lora = True
+                    elif lora_config.get("name", "") == "low_noise_model":
+                        use_low_lora = True
+            if use_high_lora:
+                high_noise_model = WanModel(
+                    self.high_noise_model_path,
+                    self.config,
+                    self.init_device,
+                    model_type="wan2.2_moe_high_noise",
+                )
+                high_lora_wrapper = WanLoraWrapper(high_noise_model)
+                for lora_config in self.config["lora_configs"]:
+                    if lora_config.get("name", "") == "high_noise_model":
+                        lora_path = lora_config["path"]
+                        strength = lora_config.get("strength", 1.0)
+                        lora_name = high_lora_wrapper.load_lora(lora_path)
+                        high_lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
+            else:
+                high_noise_model = WanDistillModel(
+                    self.high_noise_model_path,
+                    self.config,
+                    self.init_device,
+                    model_type="wan2.2_moe_high_noise",
+                )
+            if use_low_lora:
+                low_noise_model = WanModel(
+                    self.low_noise_model_path,
+                    self.config,
+                    self.init_device,
+                    model_type="wan2.2_moe_low_noise",
+                )
+                low_lora_wrapper = WanLoraWrapper(low_noise_model)
+                for lora_config in self.config["lora_configs"]:
+                    if lora_config.get("name", "") == "low_noise_model":
+                        lora_path = lora_config["path"]
+                        strength = lora_config.get("strength", 1.0)
+                        lora_name = low_lora_wrapper.load_lora(lora_path)
+                        low_lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
+            else:
+                low_noise_model = WanDistillModel(
+                    self.low_noise_model_path,
+                    self.config,
+                    self.init_device,
+                    model_type="wan2.2_moe_low_noise",
+                )
+            return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
+        else:
+            model_struct = MultiDistillModelStruct([None, None], self.config, self.config["boundary_step_index"])
+            model_struct.low_noise_model_path = self.low_noise_model_path
+            model_struct.high_noise_model_path = self.high_noise_model_path
+            model_struct.init_device = self.init_device
+            return model_struct

     def init_scheduler(self):
         if self.config["feature_caching"] == "NoCaching":
```
lightx2v/models/runners/wan/wan_runner.py

```diff
@@ -468,11 +468,37 @@ class MultiModelStruct:
     def set_scheduler(self, shared_scheduler):
         self.scheduler = shared_scheduler
         for model in self.model:
-            model.set_scheduler(shared_scheduler)
+            if model is not None:
+                model.set_scheduler(shared_scheduler)

     def infer(self, inputs):
         self.get_current_model_index()
-        self.model[self.cur_model_index].infer(inputs)
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.model[self.cur_model_index].infer(inputs)
+        else:
+            if self.model[self.cur_model_index] is not None:
+                self.model[self.cur_model_index].infer(inputs)
+            else:
+                if self.cur_model_index == 0:
+                    high_noise_model = WanModel(
+                        self.high_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_high_noise",
+                    )
+                    high_noise_model.set_scheduler(self.scheduler)
+                    self.model[0] = high_noise_model
+                    self.model[0].infer(inputs)
+                elif self.cur_model_index == 1:
+                    low_noise_model = WanModel(
+                        self.low_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_low_noise",
+                    )
+                    low_noise_model.set_scheduler(self.scheduler)
+                    self.model[1] = low_noise_model
+                    self.model[1].infer(inputs)

     @ProfilingContext4DebugL2("Swtich models in infer_main costs")
     def get_current_model_index(self):
@@ -526,40 +552,47 @@ class Wan22MoeRunner(WanRunner):
     def load_transformer(self):
         # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = WanModel(
-            self.high_noise_model_path,
-            self.config,
-            self.init_device,
-            model_type="wan2.2_moe_high_noise",
-        )
-        low_noise_model = WanModel(
-            self.low_noise_model_path,
-            self.config,
-            self.init_device,
-            model_type="wan2.2_moe_low_noise",
-        )
-        if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False)
-            for lora_config in self.config["lora_configs"]:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                base_name = os.path.basename(lora_path)
-                if base_name.startswith("high"):
-                    lora_wrapper = WanLoraWrapper(high_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-                elif base_name.startswith("low"):
-                    lora_wrapper = WanLoraWrapper(low_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-                else:
-                    raise ValueError(f"Unsupported LoRA path: {lora_path}")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            high_noise_model = WanModel(
+                self.high_noise_model_path,
+                self.config,
+                self.init_device,
+                model_type="wan2.2_moe_high_noise",
+            )
+            low_noise_model = WanModel(
+                self.low_noise_model_path,
+                self.config,
+                self.init_device,
+                model_type="wan2.2_moe_low_noise",
+            )
+            if self.config.get("lora_configs") and self.config["lora_configs"]:
+                assert not self.config.get("dit_quantized", False)
+                for lora_config in self.config["lora_configs"]:
+                    lora_path = lora_config["path"]
+                    strength = lora_config.get("strength", 1.0)
+                    base_name = os.path.basename(lora_path)
+                    if base_name.startswith("high"):
+                        lora_wrapper = WanLoraWrapper(high_noise_model)
+                        lora_name = lora_wrapper.load_lora(lora_path)
+                        lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
+                    elif base_name.startswith("low"):
+                        lora_wrapper = WanLoraWrapper(low_noise_model)
+                        lora_name = lora_wrapper.load_lora(lora_path)
+                        lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
+                    else:
+                        raise ValueError(f"Unsupported LoRA path: {lora_path}")
+            return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
+        else:
+            model_struct = MultiModelStruct([None, None], self.config, self.config["boundary"])
+            model_struct.low_noise_model_path = self.low_noise_model_path
+            model_struct.high_noise_model_path = self.high_noise_model_path
+            model_struct.init_device = self.init_device
+            return model_struct


 @RUNNER_REGISTER("wan2.2")
```
lightx2v/utils/profiler.py

```diff
 import asyncio
+import threading
 import time
 from functools import wraps
@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
 from lightx2v_platform.base.global_var import AI_DEVICE

 torch_device_module = getattr(torch, AI_DEVICE)

+_excluded_time_local = threading.local()
+
+
+def _get_excluded_time_stack():
+    if not hasattr(_excluded_time_local, "stack"):
+        _excluded_time_local.stack = []
+    return _excluded_time_local.stack
+

 class _ProfilingContext:
@@ -32,11 +40,14 @@ class _ProfilingContext:
     def __enter__(self):
         torch_device_module.synchronize()
         self.start_time = time.perf_counter()
+        _get_excluded_time_stack().append(0.0)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
         torch_device_module.synchronize()
-        elapsed = time.perf_counter() - self.start_time
+        total_elapsed = time.perf_counter() - self.start_time
+        excluded = _get_excluded_time_stack().pop()
+        elapsed = total_elapsed - excluded
         if self.enable_recorder and self.metrics_func:
             if self.metrics_labels:
                 self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
@@ -49,11 +60,14 @@ class _ProfilingContext:
     async def __aenter__(self):
         torch_device_module.synchronize()
         self.start_time = time.perf_counter()
+        _get_excluded_time_stack().append(0.0)
         return self

     async def __aexit__(self, exc_type, exc_val, exc_tb):
         torch_device_module.synchronize()
-        elapsed = time.perf_counter() - self.start_time
+        total_elapsed = time.perf_counter() - self.start_time
+        excluded = _get_excluded_time_stack().pop()
+        elapsed = total_elapsed - excluded
         if self.enable_recorder and self.metrics_func:
             if self.metrics_labels:
                 self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
@@ -103,6 +117,65 @@ class _NullContext:
         return func

+
+class _ExcludedProfilingContext:
+    """Marks a time span that should be excluded from outer profiling."""
+
+    def __init__(self, name=None):
+        self.name = name
+        if dist.is_initialized():
+            self.rank_info = f"Rank {dist.get_rank()}"
+        else:
+            self.rank_info = "Single GPU"
+
+    def __enter__(self):
+        torch_device_module.synchronize()
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch_device_module.synchronize()
+        elapsed = time.perf_counter() - self.start_time
+        stack = _get_excluded_time_stack()
+        for i in range(len(stack)):
+            stack[i] += elapsed
+        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
+            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
+        return False
+
+    async def __aenter__(self):
+        torch_device_module.synchronize()
+        self.start_time = time.perf_counter()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        torch_device_module.synchronize()
+        elapsed = time.perf_counter() - self.start_time
+        stack = _get_excluded_time_stack()
+        for i in range(len(stack)):
+            stack[i] += elapsed
+        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
+            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
+        return False
+
+    def __call__(self, func):
+        if asyncio.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                async with self:
+                    return await func(*args, **kwargs)
+
+            return async_wrapper
+        else:
+
+            @wraps(func)
+            def sync_wrapper(*args, **kwargs):
+                with self:
+                    return func(*args, **kwargs)
+
+            return sync_wrapper
+

 class _ProfilingContextL1(_ProfilingContext):
     """Level 1 profiling context with Level1_Log prefix."""
@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
 """
 ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
 ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
+ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext
```
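The exclusion mechanism is a thread-local stack of accumulators: each outer profiler pushes a `0.0` slot on entry, every excluded span adds its elapsed time to all open slots on exit, and each outer profiler subtracts its own slot when it pops. A self-contained sketch of that accounting, with `time.sleep` standing in for device work (the class names here are illustrative):

```python
import threading
import time

_local = threading.local()


def _stack():
    if not hasattr(_local, "stack"):
        _local.stack = []
    return _local.stack


class Profiled:
    def __enter__(self):
        self.start = time.perf_counter()
        _stack().append(0.0)
        return self

    def __exit__(self, *exc):
        total = time.perf_counter() - self.start
        excluded = _stack().pop()
        print(f"measured {total - excluded:.2f}s (raw {total:.2f}s, excluded {excluded:.2f}s)")


class Excluded:
    def __enter__(self):
        self.start = time.perf_counter()
        return self

    def __exit__(self, *exc):
        elapsed = time.perf_counter() - self.start
        for i in range(len(_stack())):
            _stack()[i] += elapsed  # charge every open outer profiler


with Profiled():         # reports ~0.10s, not ~0.30s
    time.sleep(0.1)      # real work
    with Excluded():
        time.sleep(0.2)  # e.g. a one-off warm-up, excluded from the outer timing
```

This mirrors why `warm_up_cpu_buffers` in manager.py is decorated with `ExcludedProfilingContext`: the one-time disk warm-up should not inflate the surrounding inference timings.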