xuwx1 / LightX2V · Commits

Commit f3b4ba24 (Unverified)
Authored Dec 08, 2025 by Gu Shiqiao, committed by GitHub on Dec 08, 2025

[Recon] Reconstruct disk-cpu-cuda offload system (#578)
Parent: 67d6c6c1
Showing 15 changed files with 512 additions and 213 deletions (+512 -213):

  lightx2v/common/offload/manager.py                                 +70   -3
  lightx2v/common/ops/mm/mm_weight.py                                +73  -45
  lightx2v/common/ops/norm/layer_norm_weight.py                      +15   -7
  lightx2v/common/ops/norm/rms_norm_weight.py                        +11   -6
  lightx2v/common/ops/tensor/tensor.py                               +11   -6
  lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py        +12   -2
  lightx2v/models/networks/wan/infer/offload/transformer_infer.py    +10   -0
  lightx2v/models/networks/wan/infer/transformer_infer.py             +2   -2
  lightx2v/models/networks/wan/infer/utils.py                        +14  -17
  lightx2v/models/networks/wan/model.py                              +20   -8
  lightx2v/models/networks/wan/weights/transformer_weights.py        +27  -25
  lightx2v/models/runners/default_runner.py                          +13   -1
  lightx2v/models/runners/wan/wan_distill_runner.py                  +90  -54
  lightx2v/models/runners/wan/wan_runner.py                          +68  -35
  lightx2v/utils/profiler.py                                         +76   -2
lightx2v/common/offload/manager.py (View file @ f3b4ba24)

+from concurrent.futures import ThreadPoolExecutor
 import torch
+from loguru import logger
 from packaging.version import parse
+from tqdm import tqdm
+from lightx2v.utils.profiler import ExcludedProfilingContext
 from lightx2v_platform.base.global_var import AI_DEVICE

 torch_device_module = getattr(torch, AI_DEVICE)

@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
         self.offload_granularity = offload_granularity
         self.init_stream = torch_device_module.Stream(priority=0)
         self.need_init_first_buffer = True
+        self.lazy_load = False
         torch_version = parse(torch.__version__.split("+")[0])
         if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
             self.cuda_load_stream = torch_device_module.Stream(priority=1)

@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
     def init_first_buffer(self, blocks, adapter_block_idx=None):
         with torch_device_module.stream(self.init_stream):
             if hasattr(self, "cpu_buffers"):
-                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
+                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0][0].state_dict(), 0, adapter_block_idx)
             else:
                 if self.offload_granularity == "block":
                     self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)

@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
     def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
         with torch_device_module.stream(self.cuda_load_stream):
             if hasattr(self, "cpu_buffers"):
-                self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
-                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
+                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[0][phase_idx].state_dict(), block_idx, adapter_block_idx)
             else:
                 self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)

@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
     def swap_phases(self):
         self.cuda_load_stream.synchronize()
         self.compute_stream.synchronize()
+
+    @ExcludedProfilingContext("🔥 warm_up_cpu_buffers")
+    def warm_up_cpu_buffers(self, blocks_num):
+        logger.info("🔥 Warming up cpu buffers...")
+        for i in tqdm(range(blocks_num)):
+            for phase in self.cpu_buffers[0]:
+                phase.load_state_dict_from_disk(i, None)
+            for phase in self.cpu_buffers[1]:
+                phase.load_state_dict_from_disk(i, None)
+        for phase in self.cpu_buffers[0]:
+            phase.load_state_dict_from_disk(0, None)
+        for phase in self.cpu_buffers[1]:
+            phase.load_state_dict_from_disk(1, None)
+        logger.info("✅ CPU buffers warm-up completed.")
+
+    def init_lazy_load(self, num_workers=6):
+        self.lazy_load = True
+        self.executor = ThreadPoolExecutor(max_workers=num_workers)
+        self.prefetch_futures = []
+        self.prefetch_block_idx = -1
+
+    def start_prefetch_block(self, block_idx, adapter_block_idx=None):
+        self.prefetch_block_idx = block_idx
+        self.prefetch_futures = []
+        for phase in self.cpu_buffers[1]:
+            future = self.executor.submit(phase.load_state_dict_from_disk, block_idx, adapter_block_idx)
+            self.prefetch_futures.append(future)
+
+    def swap_cpu_buffers(self):
+        import time
+
+        wait_start = time.time()
+        already_done = all(f.done() for f in self.prefetch_futures)
+        for f in self.prefetch_futures:
+            f.result()
+        wait_time = time.time() - wait_start
+        logger.debug(f"[Prefetch] block {self.prefetch_block_idx}: wait={wait_time:.3f}s, already_done={already_done}")
+        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]
+
+    def shutdown(self, wait=True):
+        """Shutdown the thread pool executor and wait for all pending tasks to complete."""
+        if hasattr(self, "executor") and self.executor is not None:
+            # Wait for all pending futures to complete before shutting down
+            if hasattr(self, "prefetch_futures"):
+                for f in self.prefetch_futures:
+                    try:
+                        if not f.done():
+                            f.result()
+                    except Exception:
+                        pass
+            self.executor.shutdown(wait=wait)
+            self.executor = None
+            logger.debug("ThreadPoolExecutor shut down successfully.")
+
+    def __del__(self):
+        """Cleanup method to ensure executor is shut down when object is destroyed."""
+        try:
+            if hasattr(self, "executor") and self.executor is not None:
+                self.executor.shutdown(wait=False)
+        except Exception:
+            pass
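The lazy-load additions above follow a common double-buffering pattern: while the GPU consumes the current block's CPU buffers, a thread pool reads the next block's weights from disk into a second CPU buffer set, and the two sets are flipped once the prefetch futures complete. The following is a minimal, self-contained sketch of that pattern, not LightX2V's actual classes; `load_from_disk` and the buffer objects are illustrative stand-ins.

import time
from concurrent.futures import ThreadPoolExecutor


class DoubleBufferedPrefetcher:
    """Minimal sketch of the disk -> CPU double buffering used by the offload manager."""

    def __init__(self, load_from_disk, num_phases, num_workers=6):
        # load_from_disk(block_idx, phase_idx) -> phase weights (hypothetical callable)
        self.load_from_disk = load_from_disk
        # Two CPU buffer sets: [0] is consumed by the GPU copy, [1] is being prefetched.
        self.cpu_buffers = [[None] * num_phases, [None] * num_phases]
        self.executor = ThreadPoolExecutor(max_workers=num_workers)
        self.futures = []

    def start_prefetch_block(self, block_idx):
        # Kick off one background read per phase into the back buffer.
        self.futures = [
            self.executor.submit(self._load_phase, block_idx, phase_idx)
            for phase_idx in range(len(self.cpu_buffers[1]))
        ]

    def _load_phase(self, block_idx, phase_idx):
        self.cpu_buffers[1][phase_idx] = self.load_from_disk(block_idx, phase_idx)

    def swap_cpu_buffers(self):
        # Block until the back buffer is fully populated, then flip front/back.
        for f in self.futures:
            f.result()
        self.cpu_buffers = [self.cpu_buffers[1], self.cpu_buffers[0]]

    def shutdown(self):
        self.executor.shutdown(wait=True)


if __name__ == "__main__":
    def fake_disk_read(block_idx, phase_idx):
        time.sleep(0.01)  # pretend this is a safetensors read
        return f"block{block_idx}-phase{phase_idx}"

    pf = DoubleBufferedPrefetcher(fake_disk_read, num_phases=2)
    pf.start_prefetch_block(0)
    pf.swap_cpu_buffers()                    # cpu_buffers[0] now holds block 0
    for block_idx in range(1, 4):
        pf.start_prefetch_block(block_idx)   # overlap disk I/O with "compute"
        print("computing with", pf.cpu_buffers[0])
        pf.swap_cpu_buffers()
    pf.shutdown()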
lightx2v/common/ops/mm/mm_weight.py (View file @ f3b4ba24)

+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
+from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.ggml_tensor import GGMLTensor

@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
     def _get_source_tensor(self, source_name, weight_dict=None):
         if self.lazy_load:
-            return self.lazy_load_file.get_tensor(source_name)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{source_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                return lazy_load_file.get_tensor(source_name)
         return weight_dict[source_name]

     def _create_pin_tensor(self, tensor, transpose=False):

@@ -145,15 +149,18 @@ class MMWeight(MMWeightTemplate):
         self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)

     def _load_cpu_pin_buffers(self):
-        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
-        self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
-        if self.bias_name is not None:
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias = self._create_pin_tensor(bias_tensor)
-        else:
-            self.bias = None
-            self.pin_bias = None
+        if self.lazy_load:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+                self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
+                if self.bias_name is not None:
+                    bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                    self.pin_bias = self._create_pin_tensor(bias_tensor)
+                else:
+                    self.bias = None
+                    self.pin_bias = None

     def _load_default_tensors(self, weight_dict):
         if not self.lazy_load:

@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
             else:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
-            self.pin_weight = self.pin_weight.copy_(weight_tensor)
-            del weight_tensor
             if self.bias_name is not None:
                 if self.is_post_adapter:
                     assert adapter_block_index is not None

@@ -208,9 +211,16 @@ class MMWeight(MMWeightTemplate):
                 else:
                     self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias.copy_(bias_tensor)
-            del bias_tensor
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+                self.pin_weight = self.pin_weight.copy_(weight_tensor)
+                del weight_tensor
+                if self.bias_name is not None:
+                    bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                    self.pin_bias.copy_(bias_tensor)
+                    del bias_tensor

     def load_state_dict(self, destination, block_index, adapter_block_index=None):
         if self.is_post_adapter:

@@ -283,9 +293,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self._load_default_tensors(weight_dict)

     def _load_cuda_buffers(self, weight_dict):
-        source = self.lazy_load_file if self.lazy_load else weight_dict
-        self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
-        self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+        if self.lazy_load:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
+                self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+                self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
+        else:
+            source = weight_dict
+            self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
+            self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)

     def _get_cuda_tensor_pair(self, source, is_lazy):
         if is_lazy:

@@ -318,30 +334,38 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def _get_cpu_pin_tensor_pair(self, source, is_lazy):
         if is_lazy:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
                 weight_tensor = source.get_tensor(self.weight_name)
                 scale_tensor = source.get_tensor(self.weight_scale_name)
             scale_dtype = torch.float
             pin_weight = self._create_pin_tensor(weight_tensor)
             pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
         else:
             weight_tensor = source[self.weight_name]
             scale_tensor = source[self.weight_scale_name]
             scale_dtype = torch.float
             pin_weight = self._create_pin_tensor(weight_tensor)
             pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
         return pin_weight, pin_scale

     def _get_cpu_pin_bias_tensor(self, source, is_lazy):
         if self.bias_name is None:
             return None
         if is_lazy:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as source:
                 bias_tensor = source.get_tensor(self.bias_name)
             if not self.bias_force_fp32:
                 bias_tensor = bias_tensor.to(self.infer_dtype)
+            if self.bias_force_fp32:
+                bias_tensor = bias_tensor.to(torch.float32)
+            return self._create_pin_tensor(bias_tensor)
         else:
             bias_tensor = source[self.bias_name]
             if self.bias_force_fp32:
                 bias_tensor = bias_tensor.to(torch.float32)
             return self._create_pin_tensor(bias_tensor)

     def _create_pin_tensor(self, tensor, dtype=None):
         dtype = dtype or tensor.dtype

@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
                 self.weight_scale_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_scale_name, count=1)
-            if self.weight_need_transpose:
-                weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
-            else:
-                weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
-            self.pin_weight = self.pin_weight.copy_(weight_tensor)
-            weight_scale_tensor = self.lazy_load_file.get_tensor(self.weight_scale_name)
-            self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
-            del weight_tensor
             if self.bias_name is not None:
                 if self.is_post_adapter:
                     assert adapter_block_index is not None

@@ -661,9 +674,24 @@ class MMWeightQuantTemplate(MMWeightTemplate):
                 else:
                     self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
-            self.pin_bias.copy_(bias_tensor)
-            del bias_tensor
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                if self.weight_need_transpose:
+                    weight_tensor = lazy_load_file.get_tensor(self.weight_name).t()
+                else:
+                    weight_tensor = lazy_load_file.get_tensor(self.weight_name)
+                self.pin_weight = self.pin_weight.copy_(weight_tensor)
+                del weight_tensor
+                weight_scale_tensor = lazy_load_file.get_tensor(self.weight_scale_name)
+                self.pin_weight_scale = self.pin_weight_scale.copy_(weight_scale_tensor)
+                del weight_scale_tensor
+                if self.bias_name is not None:
+                    bias_tensor = lazy_load_file.get_tensor(self.bias_name)
+                    self.pin_bias.copy_(bias_tensor)
+                    del bias_tensor

 @MM_WEIGHT_REGISTER("fp8-vllm")
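The recurring change in this file is that `lazy_load_file` is now a directory of per-block shards named `block_<idx>.safetensors`, opened on demand with `safe_open`, instead of a single long-lived file handle. A hedged sketch of that access pattern, with an illustrative directory layout and tensor naming scheme mirrored from the diff above (not the repo's exact helper):

import os

import torch
from safetensors import safe_open


def load_block_tensor_pinned(shard_dir, tensor_name, transpose=False):
    """Open the per-block shard that owns `tensor_name` and copy it into pinned CPU memory.

    Assumes names like "blocks.<block_idx>.attn.qkv.weight" and shards named
    "block_<block_idx>.safetensors" (illustrative assumption).
    """
    block_idx = tensor_name.split(".")[1]
    shard_path = os.path.join(shard_dir, f"block_{block_idx}.safetensors")
    with safe_open(shard_path, framework="pt", device="cpu") as f:
        tensor = f.get_tensor(tensor_name)
    if transpose:
        tensor = tensor.t()  # 2-D weight matrices only
    # Pinned (page-locked) memory allows async host-to-device copies on a side stream.
    pinned = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=torch.cuda.is_available())
    pinned.copy_(tensor)
    return pinned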
lightx2v/common/ops/norm/layer_norm_weight.py (View file @ f3b4ba24)

+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
+from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER

@@ -53,9 +55,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
         if name is None:
             return None
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[name]
         return tensor

@@ -151,24 +155,28 @@ class LNWeightTemplate(metaclass=ABCMeta):
     def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
         if self.weight_name is not None:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
             if self.is_post_adapter:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
             else:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
                 self.pin_weight = self.pin_weight.copy_(weight_tensor)
                 del weight_tensor
         if self.bias_name is not None:
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
             if self.is_post_adapter:
                 assert adapter_block_index is not None
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
             else:
                 self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
-            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                bias_tensor = lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
                 self.pin_bias.copy_(bias_tensor)
                 del bias_tensor
lightx2v/common/ops/norm/rms_norm_weight.py (View file @ f3b4ba24)

+import os
 import re
 from abc import ABCMeta, abstractmethod

 import torch
+from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER

@@ -46,9 +48,11 @@ class RMSWeightTemplate(metaclass=ABCMeta):
     def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(self.weight_name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.weight_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(self.weight_name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[self.weight_name]
         return tensor

@@ -107,9 +111,10 @@ class RMSWeightTemplate(metaclass=ABCMeta):
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
             else:
                 self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
-            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                weight_tensor = lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
                 self.pin_weight = self.pin_weight.copy_(weight_tensor)
                 del weight_tensor
lightx2v/common/ops/tensor/tensor.py (View file @ f3b4ba24)

+import os
 import re

 import torch
+from safetensors import safe_open

 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import TENSOR_REGISTER

@@ -39,9 +41,11 @@ class DefaultTensor:
     def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
         if self.lazy_load:
-            tensor = self.lazy_load_file.get_tensor(self.tensor_name)
-            if use_infer_dtype:
-                tensor = tensor.to(self.infer_dtype)
+            lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{self.tensor_name.split('.')[1]}.safetensors")
+            with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+                tensor = lazy_load_file.get_tensor(self.tensor_name)
+                if use_infer_dtype:
+                    tensor = tensor.to(self.infer_dtype)
         else:
             tensor = weight_dict[self.tensor_name]
         return tensor

@@ -92,7 +96,8 @@ class DefaultTensor:
             self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
         else:
             self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+        lazy_load_file_path = os.path.join(self.lazy_load_file, f"block_{block_index}.safetensors")
-        tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
+        with safe_open(lazy_load_file_path, framework="pt", device="cpu") as lazy_load_file:
+            tensor = lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
             self.pin_tensor = self.pin_tensor.copy_(tensor)
             del tensor
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py (View file @ f3b4ba24)

 import torch
 from einops import rearrange
-from flash_attn import flash_attn_varlen_qkvpacked_func
-from flash_attn.bert_padding import pad_input, unpad_input
 from loguru import logger

+try:
+    from flash_attn import flash_attn_varlen_qkvpacked_func
+except ImportError:
+    flash_attn_varlen_qkvpacked_func = None
+    logger.info("flash_attn_varlen_qkvpacked_func not available")
+
+try:
+    from flash_attn.bert_padding import pad_input, unpad_input
+except ImportError:
+    pad_input = None
+    unpad_input = None
+    logger.info("flash_attn.bert_padding not available")
+
 try:
     from flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func_v3
 except ImportError:
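The change above turns flash-attn into a soft dependency: the imports are wrapped in try/except and the symbols fall back to None when the package is missing. Call sites then need a guard before using them; a minimal sketch of that guard (the wrapper function and error message are illustrative, not the repo's actual code):

try:
    from flash_attn import flash_attn_varlen_qkvpacked_func
except ImportError:
    flash_attn_varlen_qkvpacked_func = None


def varlen_attention(qkv, cu_seqlens, max_seqlen):
    # Fail with a clear message instead of an AttributeError deep inside inference.
    if flash_attn_varlen_qkvpacked_func is None:
        raise RuntimeError("flash-attn is not installed; it is required for the no-pad attention path")
    return flash_attn_varlen_qkvpacked_func(qkv, cu_seqlens, max_seqlen)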
lightx2v/models/networks/wan/infer/offload/transformer_infer.py (View file @ f3b4ba24)

@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
         if offload_granularity != "model":
             self.offload_manager = WeightAsyncStreamManager(offload_granularity=offload_granularity)
+            self.lazy_load = self.config.get("lazy_load", False)
+            if self.lazy_load and offload_granularity == "phase":
+                self.offload_manager.init_lazy_load(num_workers=self.config.get("num_disk_workers", 4))

     def infer_with_blocks_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):

@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
     def infer_with_phases_offload(self, blocks, x, pre_infer_out):
         for block_idx in range(len(blocks)):
             self.block_idx = block_idx
+            if self.lazy_load:
+                next_prefetch = (block_idx + 1) % len(blocks)
+                self.offload_manager.start_prefetch_block(next_prefetch)
             x = self.infer_phases(block_idx, blocks, x, pre_infer_out)

             if self.clean_cuda_cache:
                 del (

@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
                 self.offload_manager.init_first_buffer(blocks)

             next_block_idx = (block_idx + 1) % len(blocks) if phase_idx == self.phases_num - 1 else block_idx
             next_phase_idx = (phase_idx + 1) % self.phases_num
+            if self.lazy_load:
+                if phase_idx == self.phases_num - 1:
+                    self.offload_manager.swap_cpu_buffers()
             self.offload_manager.prefetch_phase(next_block_idx, next_phase_idx, blocks)

             with torch_device_module.stream(self.offload_manager.compute_stream):
                 x = self.infer_phase(phase_idx, self.offload_manager.cuda_buffers[phase_idx], x, pre_infer_out)
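The phase-offload loop above overlaps three activities: the thread pool filling CPU buffers from disk, a higher-priority load stream copying the next phase's weights host-to-device, and the compute stream running the current phase. A stripped-down sketch of the two-CUDA-stream part, using stand-in tensors instead of real weight modules and a hypothetical `compute_phase` callable, looks roughly like this:

import torch


def run_blocks_with_overlap(block_weights_cpu, compute_phase):
    """Sketch: copy block i+1's weights on a side stream while computing with block i.

    block_weights_cpu: list of pinned CPU tensors (stand-ins for per-block weights).
    compute_phase: callable taking the on-GPU weights (hypothetical).
    """
    load_stream = torch.cuda.Stream(priority=-1)  # lower number = higher priority
    compute_stream = torch.cuda.Stream()
    buffers = [None, None]  # double-buffered GPU copies

    # Prime the first buffer and wait for it once, before the loop starts.
    with torch.cuda.stream(load_stream):
        buffers[0] = block_weights_cpu[0].to("cuda", non_blocking=True)
    load_stream.synchronize()

    for i in range(len(block_weights_cpu)):
        nxt = (i + 1) % len(block_weights_cpu)
        with torch.cuda.stream(load_stream):
            # Host-to-device copy of the next block, overlapped with the compute below.
            buffers[nxt % 2] = block_weights_cpu[nxt].to("cuda", non_blocking=True)
        with torch.cuda.stream(compute_stream):
            compute_phase(buffers[i % 2])
        # Equivalent of swap_phases(): both streams must finish before buffers are reused.
        load_stream.synchronize()
        compute_stream.synchronize()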
lightx2v/models/networks/wan/infer/transformer_infer.py (View file @ f3b4ba24)

@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
             cu_seqlens_qkv = torch.tensor([0, img_qkv_len], dtype=torch.int32, device="cpu").to(q.device, non_blocking=True)

         if self.clean_cuda_cache:
-            del norm1_out, norm1_weight, norm1_bias
+            del norm1_out, shift_msa, scale_msa
             torch.cuda.empty_cache()

         if self.config["seq_parallel"]:

@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
         y = phase.ffn_0.apply(norm2_out)

         if self.clean_cuda_cache:
-            del norm2_out, x, norm2_weight, norm2_bias
+            del norm2_out, x
             torch.cuda.empty_cache()

         y = torch.nn.functional.gelu(y, approximate="tanh")

         if self.clean_cuda_cache:
lightx2v/models/networks/wan/infer/utils.py (View file @ f3b4ba24)

@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
     rope_func,
 ):
     seq_len = cos_sin_cache.size(0)
-    xq_output_chunks = []
-    xk_output_chunks = []
+    x_q = torch.empty_like(xq)
+    x_k = torch.empty_like(xk)

     for start in range(0, seq_len, chunk_size):
         end = min(start + chunk_size, seq_len)
         xq_chunk = xq[start:end]
         xk_chunk = xk[start:end]
         cos_sin_chunk = cos_sin_cache[start:end]
-        xq_chunk, xk_chunk = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
-        xq_output_chunks.append(xq_chunk)
-        xk_output_chunks.append(xk_chunk)
-        torch.cuda.empty_cache()
+        xq_chunk_out, xk_chunk_out = rope_func(xq_chunk, xk_chunk, cos_sin_chunk)
+        x_q[start:end].copy_(xq_chunk_out, non_blocking=True)
+        x_k[start:end].copy_(xk_chunk_out, non_blocking=True)
+        del xq_chunk_out, xk_chunk_out

-    x_q = torch.cat(xq_output_chunks, dim=0)
-    del xq_output_chunks
-    torch.cuda.empty_cache()
-
-    x_k = torch.cat(xk_output_chunks, dim=0)
-    del xk_output_chunks
-    torch.cuda.empty_cache()
-
-    return x_q.to(GET_DTYPE()), x_k.to(GET_DTYPE())
+    target_dtype = GET_DTYPE()
+    if x_q.dtype != target_dtype:
+        x_q = x_q.to(target_dtype)
+    if x_k.dtype != target_dtype:
+        x_k = x_k.to(target_dtype)
+
+    return x_q, x_k


 def apply_wan_rope_with_flashinfer(
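The RoPE helper above switches from collecting per-chunk outputs in Python lists and concatenating them at the end to writing each chunk into preallocated output tensors, which avoids holding both the chunk list and the concatenated copy at the memory peak. A small self-contained comparison of the two shapes of the loop, with a toy elementwise rotation in place of the real rope_func (shapes are assumed broadcast-compatible):

import torch


def toy_rope(xq_chunk, xk_chunk, cos_sin_chunk):
    # Stand-in for the real rope_func: just scale by the cached values.
    return xq_chunk * cos_sin_chunk, xk_chunk * cos_sin_chunk


def apply_chunked_preallocated(xq, xk, cos_sin_cache, chunk_size=1024):
    # New style: write chunks into preallocated outputs, no torch.cat at the end.
    x_q = torch.empty_like(xq)
    x_k = torch.empty_like(xk)
    for start in range(0, cos_sin_cache.size(0), chunk_size):
        end = min(start + chunk_size, cos_sin_cache.size(0))
        q_out, k_out = toy_rope(xq[start:end], xk[start:end], cos_sin_cache[start:end])
        x_q[start:end].copy_(q_out, non_blocking=True)
        x_k[start:end].copy_(k_out, non_blocking=True)
    return x_q, x_k


def apply_chunked_with_cat(xq, xk, cos_sin_cache, chunk_size=1024):
    # Old style: append chunks and concatenate, which peaks at roughly twice the memory.
    q_chunks, k_chunks = [], []
    for start in range(0, cos_sin_cache.size(0), chunk_size):
        end = min(start + chunk_size, cos_sin_cache.size(0))
        q_out, k_out = toy_rope(xq[start:end], xk[start:end], cos_sin_cache[start:end])
        q_chunks.append(q_out)
        k_chunks.append(k_out)
    return torch.cat(q_chunks, dim=0), torch.cat(k_chunks, dim=0)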
lightx2v/models/networks/wan/model.py (View file @ f3b4ba24)

@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
             safetensors_files = [safetensors_path]
         if self.lazy_load:
-            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
-            self.lazy_load_path = safetensors_files[0]
+            self.lazy_load_path = safetensors_path
+            non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+            if os.path.exists(non_block_file):
+                safetensors_files = [non_block_file]
+            else:
+                raise ValueError(f"Non-block file not found in {safetensors_path}")

         weight_dict = {}
         for file_path in safetensors_files:

@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
     def _load_quant_ckpt(self, unified_dtype, sensitive_layer):
         remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
         if self.config.get("dit_quantized_ckpt", None):
             safetensors_path = self.config["dit_quantized_ckpt"]
         else:

@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
             safetensors_path = os.path.dirname(safetensors_path)
         if self.lazy_load:
-            assert len(safetensors_files) == 1, "Only support single safetensors file in lazy load mode"
-            self.lazy_load_path = safetensors_files[0]
+            self.lazy_load_path = safetensors_path
+            non_block_file = os.path.join(safetensors_path, "non_block.safetensors")
+            if os.path.exists(non_block_file):
+                safetensors_files = [non_block_file]
+            else:
+                raise ValueError(f"Non-block file not found in {safetensors_path}, Please check the lazy load model path")

         weight_dict = {}
         for safetensor_path in safetensors_files:

@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
         self.post_infer = self.post_infer_class(self.config)
         self.transformer_infer = self.transformer_infer_class(self.config)
         if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
-            if self.lazy_load:
-                self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
+            self._init_offload_manager()
+
+    def _init_offload_manager(self):
+        self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
+        if self.lazy_load:
+            self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
+            if self.config.get("warm_up_cpu_buffers", False):
+                self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num)

     def set_scheduler(self, scheduler):
         self.scheduler = scheduler
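Taken together, the model-side changes are driven by a handful of config keys that appear in this diff: lazy_load (the checkpoint directory must contain per-block shards plus non_block.safetensors), offload_granularity, num_disk_workers, warm_up_cpu_buffers, and unload_modules. A sketch of what a disk-offload configuration could look like, with values and the surrounding schema being assumptions rather than something this commit defines:

# Illustrative only: key names are taken from the diff above, the values and the
# overall config schema are assumptions.
disk_offload_config = {
    "cpu_offload": True,
    "offload_granularity": "phase",   # lazy load is only wired up for phase granularity
    "lazy_load": True,                # expects block_<idx>.safetensors + non_block.safetensors
    "num_disk_workers": 4,            # ThreadPoolExecutor size for disk prefetch
    "warm_up_cpu_buffers": True,      # optional startup pass that touches every block once
    "unload_modules": True,           # rebuild models on demand instead of keeping both resident
}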
lightx2v/models/networks/wan/weights/transformer_weights.py (View file @ f3b4ba24)

-from safetensors import safe_open
 from lightx2v.common.modules.weight_module import WeightModule, WeightModuleList
 from lightx2v.utils.registry_factory import (
     ATTN_WEIGHT_REGISTER,

@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
         if config.get("do_mm_calib", False):
             self.mm_type = "Calib"
         self.lazy_load = self.config.get("lazy_load", False)
-        if not self.lazy_load:
-            self.lazy_load_file = None
-        else:
-            self.lazy_load_file = safe_open(lazy_load_path, framework="pt", device="cpu")
         self.blocks = WeightModuleList(
             [
                 WanTransformerAttentionBlock(

@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
                     create_cpu_buffer=False,
                     block_prefix="blocks",
                     lazy_load=self.lazy_load,
-                    lazy_load_file=self.lazy_load_file,
+                    lazy_load_path=lazy_load_path,
                 )
                 for i in range(self.blocks_num)
             ]
         )
-        self.register_offload_buffers(config)
+        self.register_offload_buffers(config, lazy_load_path)
         self.add_module("blocks", self.blocks)

         # non blocks weights

@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
         self.add_module("head", MM_WEIGHT_REGISTER["Default"]("head.head.weight", "head.head.bias"))
         self.register_parameter("head_modulation", TENSOR_REGISTER["Default"]("head.modulation"))

-    def register_offload_buffers(self, config):
+    def register_offload_buffers(self, config, lazy_load_path):
         if config["cpu_offload"]:
             if config["offload_granularity"] == "block":
                 self.offload_blocks_num = 2

@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
                         create_cpu_buffer=False,
                         block_prefix="blocks",
                         lazy_load=self.lazy_load,
-                        lazy_load_file=self.lazy_load_file,
+                        lazy_load_path=lazy_load_path,
                     )
                     for i in range(self.offload_blocks_num)
                 ]

@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
                         create_cpu_buffer=True,
                         block_prefix="blocks",
                         lazy_load=self.lazy_load,
-                        lazy_load_file=self.lazy_load_file,
+                        lazy_load_path=lazy_load_path,
                     )
                     for i in range(self.offload_blocks_num)
                 ]

@@ -104,22 +98,27 @@ class WanTransformerWeights(WeightModule):
                         create_cpu_buffer=False,
                         block_prefix="blocks",
                         lazy_load=self.lazy_load,
-                        lazy_load_file=self.lazy_load_file,
+                        lazy_load_path=lazy_load_path,
                     ).compute_phases
             self.add_module("offload_phase_cuda_buffers", self.offload_phase_cuda_buffers)
             self.offload_block_cuda_buffers = None

         if self.lazy_load:
-            self.offload_phase_cpu_buffers = WanTransformerAttentionBlock(
-                block_index=0,
-                task=self.task,
-                mm_type=self.mm_type,
-                config=self.config,
-                create_cuda_buffer=False,
-                create_cpu_buffer=True,
-                block_prefix="blocks",
-                lazy_load=self.lazy_load,
-                lazy_load_file=self.lazy_load_file,
-            ).compute_phases
+            self.offload_phase_cpu_buffers = WeightModuleList(
+                [
+                    WanTransformerAttentionBlock(
+                        block_index=i,
+                        task=self.task,
+                        mm_type=self.mm_type,
+                        config=self.config,
+                        create_cuda_buffer=False,
+                        create_cpu_buffer=True,
+                        block_prefix="blocks",
+                        lazy_load=self.lazy_load,
+                        lazy_load_path=lazy_load_path,
+                    ).compute_phases
+                    for i in range(2)
+                ]
+            )
             self.add_module("offload_phase_cpu_buffers", self.offload_phase_cpu_buffers)
             self.offload_block_cpu_buffers = None

@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
         create_cpu_buffer=False,
         block_prefix="blocks",
         lazy_load=False,
-        lazy_load_file=None,
+        lazy_load_path=None,
     ):
         super().__init__()
         self.block_index = block_index

@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
         self.quant_method = config.get("quant_method", None)
         self.lazy_load = lazy_load
-        self.lazy_load_file = lazy_load_file
+        if self.lazy_load:
+            self.lazy_load_file = lazy_load_path
+        else:
+            self.lazy_load_file = None
         self.compute_phases = WeightModuleList(
             [
lightx2v/models/runners/default_runner.py (View file @ f3b4ba24)

@@ -185,7 +185,19 @@ class DefaultRunner(BaseRunner):
         del self.inputs
         self.input_info = None
         if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
-            del self.model
+            if hasattr(self.model, "model") and len(self.model.model) == 2:  # MultiModelStruct
+                for model in self.model.model:
+                    if hasattr(model.transformer_infer, "offload_manager"):
+                        del model.transformer_infer.offload_manager
+                    torch.cuda.empty_cache()
+                    gc.collect()
+                    del model
+            else:
+                if hasattr(self.model.transformer_infer, "offload_manager"):
+                    del self.model.transformer_infer.offload_manager
+                torch.cuda.empty_cache()
+                gc.collect()
+                del self.model

         if self.config.get("do_mm_calib", False):
             calib_path = os.path.join(os.getcwd(), "calib.pt")
             torch.save(CALIB, calib_path)
lightx2v/models/runners/wan/wan_distill_runner.py (View file @ f3b4ba24)

@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
             self.to_cuda(model_index=1)
             self.cur_model_index = 1

+    def infer(self, inputs):
+        self.get_current_model_index()
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            self.model[self.cur_model_index].infer(inputs)
+        else:
+            if self.model[self.cur_model_index] is not None:
+                self.model[self.cur_model_index].infer(inputs)
+            else:
+                if self.cur_model_index == 0:
+                    high_noise_model = WanDistillModel(
+                        self.high_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_high_noise",
+                    )
+                    high_noise_model.set_scheduler(self.scheduler)
+                    self.model[0] = high_noise_model
+                    self.model[0].infer(inputs)
+                elif self.cur_model_index == 1:
+                    low_noise_model = WanDistillModel(
+                        self.low_noise_model_path,
+                        self.config,
+                        self.init_device,
+                        model_type="wan2.2_moe_low_noise",
+                    )
+                    low_noise_model.set_scheduler(self.scheduler)
+                    self.model[1] = low_noise_model
+                    self.model[1].infer(inputs)

 @RUNNER_REGISTER("wan2.2_moe_distill")
 class Wan22MoeDistillRunner(WanDistillRunner):

@@ -101,61 +130,68 @@ class Wan22MoeDistillRunner(WanDistillRunner):
             raise FileNotFoundError(f"Low Noise Model does not find")

     def load_transformer(self):
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
             use_high_lora, use_low_lora = False, False
             if self.config.get("lora_configs") and self.config["lora_configs"]:
                 for lora_config in self.config["lora_configs"]:
                     if lora_config.get("name", "") == "high_noise_model":
                         use_high_lora = True
                     elif lora_config.get("name", "") == "low_noise_model":
                         use_low_lora = True
             if use_high_lora:
                 high_noise_model = WanModel(
                     self.high_noise_model_path,
                     self.config,
                     self.init_device,
                     model_type="wan2.2_moe_high_noise",
                 )
                 high_lora_wrapper = WanLoraWrapper(high_noise_model)
                 for lora_config in self.config["lora_configs"]:
                     if lora_config.get("name", "") == "high_noise_model":
                         lora_path = lora_config["path"]
                         strength = lora_config.get("strength", 1.0)
                         lora_name = high_lora_wrapper.load_lora(lora_path)
                         high_lora_wrapper.apply_lora(lora_name, strength)
                         logger.info(f"High noise model loaded LoRA: {lora_name} with strength: {strength}")
             else:
                 high_noise_model = WanDistillModel(
                     self.high_noise_model_path,
                     self.config,
                     self.init_device,
                     model_type="wan2.2_moe_high_noise",
                 )
             if use_low_lora:
                 low_noise_model = WanModel(
                     self.low_noise_model_path,
                     self.config,
                     self.init_device,
                     model_type="wan2.2_moe_low_noise",
                 )
                 low_lora_wrapper = WanLoraWrapper(low_noise_model)
                 for lora_config in self.config["lora_configs"]:
                     if lora_config.get("name", "") == "low_noise_model":
                         lora_path = lora_config["path"]
                         strength = lora_config.get("strength", 1.0)
                         lora_name = low_lora_wrapper.load_lora(lora_path)
                         low_lora_wrapper.apply_lora(lora_name, strength)
                         logger.info(f"Low noise model loaded LoRA: {lora_name} with strength: {strength}")
             else:
                 low_noise_model = WanDistillModel(
                     self.low_noise_model_path,
                     self.config,
                     self.init_device,
                     model_type="wan2.2_moe_low_noise",
                 )
             return MultiDistillModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary_step_index"])
+        else:
+            model_struct = MultiDistillModelStruct([None, None], self.config, self.config["boundary_step_index"])
+            model_struct.low_noise_model_path = self.low_noise_model_path
+            model_struct.high_noise_model_path = self.high_noise_model_path
+            model_struct.init_device = self.init_device
+            return model_struct

     def init_scheduler(self):
         if self.config["feature_caching"] == "NoCaching":
lightx2v/models/runners/wan/wan_runner.py
View file @
f3b4ba24
...
@@ -468,11 +468,37 @@ class MultiModelStruct:
...
@@ -468,11 +468,37 @@ class MultiModelStruct:
def
set_scheduler
(
self
,
shared_scheduler
):
def
set_scheduler
(
self
,
shared_scheduler
):
self
.
scheduler
=
shared_scheduler
self
.
scheduler
=
shared_scheduler
for
model
in
self
.
model
:
for
model
in
self
.
model
:
model
.
set_scheduler
(
shared_scheduler
)
if
model
is
not
None
:
model
.
set_scheduler
(
shared_scheduler
)
def
infer
(
self
,
inputs
):
def
infer
(
self
,
inputs
):
self
.
get_current_model_index
()
self
.
get_current_model_index
()
self
.
model
[
self
.
cur_model_index
].
infer
(
inputs
)
if
not
self
.
config
.
get
(
"lazy_load"
,
False
)
and
not
self
.
config
.
get
(
"unload_modules"
,
False
):
self
.
model
[
self
.
cur_model_index
].
infer
(
inputs
)
else
:
if
self
.
model
[
self
.
cur_model_index
]
is
not
None
:
self
.
model
[
self
.
cur_model_index
].
infer
(
inputs
)
else
:
if
self
.
cur_model_index
==
0
:
high_noise_model
=
WanModel
(
self
.
high_noise_model_path
,
self
.
config
,
self
.
init_device
,
model_type
=
"wan2.2_moe_high_noise"
,
)
high_noise_model
.
set_scheduler
(
self
.
scheduler
)
self
.
model
[
0
]
=
high_noise_model
self
.
model
[
0
].
infer
(
inputs
)
elif
self
.
cur_model_index
==
1
:
low_noise_model
=
WanModel
(
self
.
low_noise_model_path
,
self
.
config
,
self
.
init_device
,
model_type
=
"wan2.2_moe_low_noise"
,
)
low_noise_model
.
set_scheduler
(
self
.
scheduler
)
self
.
model
[
1
]
=
low_noise_model
self
.
model
[
1
].
infer
(
inputs
)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
def
get_current_model_index
(
self
):
def
get_current_model_index
(
self
):
...
@@ -526,40 +552,47 @@ class Wan22MoeRunner(WanRunner):
...
@@ -526,40 +552,47 @@ class Wan22MoeRunner(WanRunner):
    def load_transformer(self):
        # encoder -> high_noise_model -> low_noise_model -> vae -> video_output
-        high_noise_model = WanModel(
-            self.high_noise_model_path,
-            self.config,
-            self.init_device,
-            model_type="wan2.2_moe_high_noise",
-        )
-        low_noise_model = WanModel(
-            self.low_noise_model_path,
-            self.config,
-            self.init_device,
-            model_type="wan2.2_moe_low_noise",
-        )
-        if self.config.get("lora_configs") and self.config["lora_configs"]:
-            assert not self.config.get("dit_quantized", False)
-            for lora_config in self.config["lora_configs"]:
-                lora_path = lora_config["path"]
-                strength = lora_config.get("strength", 1.0)
-                base_name = os.path.basename(lora_path)
-                if base_name.startswith("high"):
-                    lora_wrapper = WanLoraWrapper(high_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-                elif base_name.startswith("low"):
-                    lora_wrapper = WanLoraWrapper(low_noise_model)
-                    lora_name = lora_wrapper.load_lora(lora_path)
-                    lora_wrapper.apply_lora(lora_name, strength)
-                    logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
-                else:
-                    raise ValueError(f"Unsupported LoRA path: {lora_path}")
-        return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
+        if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
+            high_noise_model = WanModel(
+                self.high_noise_model_path,
+                self.config,
+                self.init_device,
+                model_type="wan2.2_moe_high_noise",
+            )
+            low_noise_model = WanModel(
+                self.low_noise_model_path,
+                self.config,
+                self.init_device,
+                model_type="wan2.2_moe_low_noise",
+            )
+            if self.config.get("lora_configs") and self.config["lora_configs"]:
+                assert not self.config.get("dit_quantized", False)
+                for lora_config in self.config["lora_configs"]:
+                    lora_path = lora_config["path"]
+                    strength = lora_config.get("strength", 1.0)
+                    base_name = os.path.basename(lora_path)
+                    if base_name.startswith("high"):
+                        lora_wrapper = WanLoraWrapper(high_noise_model)
+                        lora_name = lora_wrapper.load_lora(lora_path)
+                        lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
+                    elif base_name.startswith("low"):
+                        lora_wrapper = WanLoraWrapper(low_noise_model)
+                        lora_name = lora_wrapper.load_lora(lora_path)
+                        lora_wrapper.apply_lora(lora_name, strength)
+                        logger.info(f"Loaded LoRA: {lora_name} with strength: {strength}")
+                    else:
+                        raise ValueError(f"Unsupported LoRA path: {lora_path}")
+            return MultiModelStruct([high_noise_model, low_noise_model], self.config, self.config["boundary"])
+        else:
+            model_struct = MultiModelStruct([None, None], self.config, self.config["boundary"])
+            model_struct.low_noise_model_path = self.low_noise_model_path
+            model_struct.high_noise_model_path = self.high_noise_model_path
+            model_struct.init_device = self.init_device
+            return model_struct


@RUNNER_REGISTER("wan2.2")
...
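For reference, the keys read by `load_transformer` and `infer` above could be combined in a config roughly like the sketch below; the paths and numeric values are placeholders rather than shipped defaults, and only the key names are taken from the diff:

# Hypothetical config fragment exercising the keys used above.
config = {
    "lazy_load": True,       # defer building the two WanModel experts until first infer()
    "unload_modules": False,
    "boundary": 0.875,       # hand-off point between high- and low-noise experts (placeholder value)
    "lora_configs": [        # only honored on the eager path; basenames must start with "high"/"low"
        {"path": "/path/to/high_noise_lora.safetensors", "strength": 1.0},
        {"path": "/path/to/low_noise_lora.safetensors", "strength": 1.0},
    ],
}

if not config.get("lazy_load", False) and not config.get("unload_modules", False):
    print("eager: both experts are built inside load_transformer")
else:
    print("lazy: MultiModelStruct builds each expert the first time it is needed")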
lightx2v/utils/profiler.py
import asyncio
+import threading
import time
from functools import wraps
...
@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
from lightx2v_platform.base.global_var import AI_DEVICE

torch_device_module = getattr(torch, AI_DEVICE)

+_excluded_time_local = threading.local()
+
+
+def _get_excluded_time_stack():
+    if not hasattr(_excluded_time_local, "stack"):
+        _excluded_time_local.stack = []
+    return _excluded_time_local.stack


class _ProfilingContext:
...
@@ -32,11 +40,14 @@ class _ProfilingContext:
    def __enter__(self):
        torch_device_module.synchronize()
        self.start_time = time.perf_counter()
+        _get_excluded_time_stack().append(0.0)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch_device_module.synchronize()
-        elapsed = time.perf_counter() - self.start_time
+        total_elapsed = time.perf_counter() - self.start_time
+        excluded = _get_excluded_time_stack().pop()
+        elapsed = total_elapsed - excluded
        if self.enable_recorder and self.metrics_func:
            if self.metrics_labels:
                self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
...
@@ -49,11 +60,14 @@ class _ProfilingContext:
    async def __aenter__(self):
        torch_device_module.synchronize()
        self.start_time = time.perf_counter()
+        _get_excluded_time_stack().append(0.0)
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        torch_device_module.synchronize()
-        elapsed = time.perf_counter() - self.start_time
+        total_elapsed = time.perf_counter() - self.start_time
+        excluded = _get_excluded_time_stack().pop()
+        elapsed = total_elapsed - excluded
        if self.enable_recorder and self.metrics_func:
            if self.metrics_labels:
                self.metrics_func.labels(*self.metrics_labels).observe(elapsed)
...
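The thread-local stack makes the exclusion compose across nested profiling contexts: every open profiling frame pushes `0.0`, an excluded span adds its duration to all open frames, and each frame subtracts its accumulated value when it exits. A torch-free toy walk-through of that bookkeeping:

# Pure-Python illustration of the excluded-time stack used above.
stack = []

stack.append(0.0)  # outer profiling frame opens
stack.append(0.0)  # nested profiling frame opens

excluded_span = 0.25  # seconds spent inside an ExcludedProfilingContext
for i in range(len(stack)):
    stack[i] += excluded_span  # charged to every open frame

inner_excluded = stack.pop()  # nested frame exits: subtracts 0.25 from its elapsed time
outer_excluded = stack.pop()  # outer frame exits: also subtracts 0.25
assert inner_excluded == outer_excluded == 0.25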
@@ -103,6 +117,65 @@ class _NullContext:
        return func
+
+
+class _ExcludedProfilingContext:
"""用于标记应该从外层 profiling 中排除的时间段"""
+
+    def __init__(self, name=None):
+        self.name = name
+        if dist.is_initialized():
+            self.rank_info = f"Rank {dist.get_rank()}"
+        else:
+            self.rank_info = "Single GPU"
+
+    def __enter__(self):
+        torch_device_module.synchronize()
+        self.start_time = time.perf_counter()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        torch_device_module.synchronize()
+        elapsed = time.perf_counter() - self.start_time
+        stack = _get_excluded_time_stack()
+        for i in range(len(stack)):
+            stack[i] += elapsed
+        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
+            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
+        return False
+
+    async def __aenter__(self):
+        torch_device_module.synchronize()
+        self.start_time = time.perf_counter()
+        return self
+
+    async def __aexit__(self, exc_type, exc_val, exc_tb):
+        torch_device_module.synchronize()
+        elapsed = time.perf_counter() - self.start_time
+        stack = _get_excluded_time_stack()
+        for i in range(len(stack)):
+            stack[i] += elapsed
+        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
+            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
+        return False
+
+    def __call__(self, func):
+        if asyncio.iscoroutinefunction(func):
+
+            @wraps(func)
+            async def async_wrapper(*args, **kwargs):
+                async with self:
+                    return await func(*args, **kwargs)
+
+            return async_wrapper
+        else:
+
+            @wraps(func)
+            def sync_wrapper(*args, **kwargs):
+                with self:
+                    return func(*args, **kwargs)
+
+            return sync_wrapper


class _ProfilingContextL1(_ProfilingContext):
    """Level 1 profiling context with Level1_Log prefix."""
...
@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
"""
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
+ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext
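Putting the pieces together, a typical use is to wrap a weight-loading step in `ExcludedProfilingContext` inside an outer profiling block so that disk/PCIe time is not attributed to compute. The sketch below assumes `PROFILING_DEBUG_LEVEL >= 1` (otherwise both names resolve to `_NullContext`) and uses placeholder workloads:

import time

from lightx2v.utils.profiler import ExcludedProfilingContext, ProfilingContext4DebugL1


def _compute():
    time.sleep(0.05)  # placeholder for real inference work


def _load_weights():
    time.sleep(0.10)  # placeholder for a disk -> CPU -> GPU transfer


with ProfilingContext4DebugL1("infer block"):
    _compute()  # counted by the outer context
    with ExcludedProfilingContext("load block weights"):
        _load_weights()  # added to the exclusion stack, subtracted from the outer timing
    _compute()  # counted by the outer context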