Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
f3b4ba24
Unverified
Commit
f3b4ba24
authored
Dec 08, 2025
by
Gu Shiqiao
Committed by
GitHub
Dec 08, 2025
Browse files
[Recon] Reconstruct disk-cpu-cuda offload system (#578)
parent
67d6c6c1
Changes
15
Show whitespace changes
Inline
Side-by-side
Showing
15 changed files
with
512 additions
and
213 deletions
+512
-213
lightx2v/common/offload/manager.py
lightx2v/common/offload/manager.py
+70
-3
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+73
-45
lightx2v/common/ops/norm/layer_norm_weight.py
lightx2v/common/ops/norm/layer_norm_weight.py
+15
-7
lightx2v/common/ops/norm/rms_norm_weight.py
lightx2v/common/ops/norm/rms_norm_weight.py
+11
-6
lightx2v/common/ops/tensor/tensor.py
lightx2v/common/ops/tensor/tensor.py
+11
-6
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py
+12
-2
lightx2v/models/networks/wan/infer/offload/transformer_infer.py
...2v/models/networks/wan/infer/offload/transformer_infer.py
+10
-0
lightx2v/models/networks/wan/infer/transformer_infer.py
lightx2v/models/networks/wan/infer/transformer_infer.py
+2
-2
lightx2v/models/networks/wan/infer/utils.py
lightx2v/models/networks/wan/infer/utils.py
+14
-17
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+20
-8
lightx2v/models/networks/wan/weights/transformer_weights.py
lightx2v/models/networks/wan/weights/transformer_weights.py
+27
-25
lightx2v/models/runners/default_runner.py
lightx2v/models/runners/default_runner.py
+13
-1
lightx2v/models/runners/wan/wan_distill_runner.py
lightx2v/models/runners/wan/wan_distill_runner.py
+90
-54
lightx2v/models/runners/wan/wan_runner.py
lightx2v/models/runners/wan/wan_runner.py
+68
-35
lightx2v/utils/profiler.py
lightx2v/utils/profiler.py
+76
-2
No files found.
lightx2v/common/offload/manager.py
View file @
f3b4ba24
from
concurrent.futures
import
ThreadPoolExecutor
import
torch
import
torch
from
loguru
import
logger
from
packaging.version
import
parse
from
packaging.version
import
parse
from
tqdm
import
tqdm
from
lightx2v.utils.profiler
import
ExcludedProfilingContext
from
lightx2v_platform.base.global_var
import
AI_DEVICE
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
...
@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
...
@@ -11,6 +16,7 @@ class WeightAsyncStreamManager(object):
self
.
offload_granularity
=
offload_granularity
self
.
offload_granularity
=
offload_granularity
self
.
init_stream
=
torch_device_module
.
Stream
(
priority
=
0
)
self
.
init_stream
=
torch_device_module
.
Stream
(
priority
=
0
)
self
.
need_init_first_buffer
=
True
self
.
need_init_first_buffer
=
True
self
.
lazy_load
=
False
torch_version
=
parse
(
torch
.
__version__
.
split
(
"+"
)[
0
])
torch_version
=
parse
(
torch
.
__version__
.
split
(
"+"
)[
0
])
if
AI_DEVICE
==
"cuda"
and
torch_version
>=
parse
(
"2.7"
):
if
AI_DEVICE
==
"cuda"
and
torch_version
>=
parse
(
"2.7"
):
self
.
cuda_load_stream
=
torch_device_module
.
Stream
(
priority
=
1
)
self
.
cuda_load_stream
=
torch_device_module
.
Stream
(
priority
=
1
)
...
@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
...
@@ -44,7 +50,7 @@ class WeightAsyncStreamManager(object):
def
init_first_buffer
(
self
,
blocks
,
adapter_block_idx
=
None
):
def
init_first_buffer
(
self
,
blocks
,
adapter_block_idx
=
None
):
with
torch_device_module
.
stream
(
self
.
init_stream
):
with
torch_device_module
.
stream
(
self
.
init_stream
):
if
hasattr
(
self
,
"cpu_buffers"
):
if
hasattr
(
self
,
"cpu_buffers"
):
self
.
cuda_buffers
[
0
].
load_state_dict
(
self
.
cpu_buffers
[
0
].
state_dict
(),
0
,
adapter_block_idx
)
self
.
cuda_buffers
[
0
].
load_state_dict
(
self
.
cpu_buffers
[
0
]
[
0
]
.
state_dict
(),
0
,
adapter_block_idx
)
else
:
else
:
if
self
.
offload_granularity
==
"block"
:
if
self
.
offload_granularity
==
"block"
:
self
.
cuda_buffers
[
0
].
load_state_dict
(
blocks
[
0
].
state_dict
(),
0
,
adapter_block_idx
)
self
.
cuda_buffers
[
0
].
load_state_dict
(
blocks
[
0
].
state_dict
(),
0
,
adapter_block_idx
)
...
@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
...
@@ -64,8 +70,7 @@ class WeightAsyncStreamManager(object):
def
prefetch_phase
(
self
,
block_idx
,
phase_idx
,
blocks
,
adapter_block_idx
=
None
):
def
prefetch_phase
(
self
,
block_idx
,
phase_idx
,
blocks
,
adapter_block_idx
=
None
):
with
torch_device_module
.
stream
(
self
.
cuda_load_stream
):
with
torch_device_module
.
stream
(
self
.
cuda_load_stream
):
if
hasattr
(
self
,
"cpu_buffers"
):
if
hasattr
(
self
,
"cpu_buffers"
):
self
.
cpu_buffers
[
phase_idx
].
load_state_dict_from_disk
(
block_idx
,
adapter_block_idx
)
self
.
cuda_buffers
[
phase_idx
].
load_state_dict
(
self
.
cpu_buffers
[
0
][
phase_idx
].
state_dict
(),
block_idx
,
adapter_block_idx
)
self
.
cuda_buffers
[
phase_idx
].
load_state_dict
(
self
.
cpu_buffers
[
phase_idx
].
state_dict
(),
block_idx
,
adapter_block_idx
)
else
:
else
:
self
.
cuda_buffers
[
phase_idx
].
load_state_dict
(
blocks
[
block_idx
].
compute_phases
[
phase_idx
].
state_dict
(),
block_idx
,
adapter_block_idx
)
self
.
cuda_buffers
[
phase_idx
].
load_state_dict
(
blocks
[
block_idx
].
compute_phases
[
phase_idx
].
state_dict
(),
block_idx
,
adapter_block_idx
)
...
@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
...
@@ -80,3 +85,65 @@ class WeightAsyncStreamManager(object):
def
swap_phases
(
self
):
def
swap_phases
(
self
):
self
.
cuda_load_stream
.
synchronize
()
self
.
cuda_load_stream
.
synchronize
()
self
.
compute_stream
.
synchronize
()
self
.
compute_stream
.
synchronize
()
@
ExcludedProfilingContext
(
"🔥 warm_up_cpu_buffers"
)
def
warm_up_cpu_buffers
(
self
,
blocks_num
):
logger
.
info
(
"🔥 Warming up cpu buffers..."
)
for
i
in
tqdm
(
range
(
blocks_num
)):
for
phase
in
self
.
cpu_buffers
[
0
]:
phase
.
load_state_dict_from_disk
(
i
,
None
)
for
phase
in
self
.
cpu_buffers
[
1
]:
phase
.
load_state_dict_from_disk
(
i
,
None
)
for
phase
in
self
.
cpu_buffers
[
0
]:
phase
.
load_state_dict_from_disk
(
0
,
None
)
for
phase
in
self
.
cpu_buffers
[
1
]:
phase
.
load_state_dict_from_disk
(
1
,
None
)
logger
.
info
(
"✅ CPU buffers warm-up completed."
)
def
init_lazy_load
(
self
,
num_workers
=
6
):
self
.
lazy_load
=
True
self
.
executor
=
ThreadPoolExecutor
(
max_workers
=
num_workers
)
self
.
prefetch_futures
=
[]
self
.
prefetch_block_idx
=
-
1
def
start_prefetch_block
(
self
,
block_idx
,
adapter_block_idx
=
None
):
self
.
prefetch_block_idx
=
block_idx
self
.
prefetch_futures
=
[]
for
phase
in
self
.
cpu_buffers
[
1
]:
future
=
self
.
executor
.
submit
(
phase
.
load_state_dict_from_disk
,
block_idx
,
adapter_block_idx
)
self
.
prefetch_futures
.
append
(
future
)
def
swap_cpu_buffers
(
self
):
import
time
wait_start
=
time
.
time
()
already_done
=
all
(
f
.
done
()
for
f
in
self
.
prefetch_futures
)
for
f
in
self
.
prefetch_futures
:
f
.
result
()
wait_time
=
time
.
time
()
-
wait_start
logger
.
debug
(
f
"[Prefetch] block
{
self
.
prefetch_block_idx
}
: wait=
{
wait_time
:.
3
f
}
s, already_done=
{
already_done
}
"
)
self
.
cpu_buffers
=
[
self
.
cpu_buffers
[
1
],
self
.
cpu_buffers
[
0
]]
def
shutdown
(
self
,
wait
=
True
):
"""Shutdown the thread pool executor and wait for all pending tasks to complete."""
if
hasattr
(
self
,
"executor"
)
and
self
.
executor
is
not
None
:
# Wait for all pending futures to complete before shutting down
if
hasattr
(
self
,
"prefetch_futures"
):
for
f
in
self
.
prefetch_futures
:
try
:
if
not
f
.
done
():
f
.
result
()
except
Exception
:
pass
self
.
executor
.
shutdown
(
wait
=
wait
)
self
.
executor
=
None
logger
.
debug
(
"ThreadPoolExecutor shut down successfully."
)
def
__del__
(
self
):
"""Cleanup method to ensure executor is shut down when object is destroyed."""
try
:
if
hasattr
(
self
,
"executor"
)
and
self
.
executor
is
not
None
:
self
.
executor
.
shutdown
(
wait
=
False
)
except
Exception
:
pass
lightx2v/common/ops/mm/mm_weight.py
View file @
f3b4ba24
import
os
import
re
import
re
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch
import
torch
from
safetensors
import
safe_open
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.ggml_tensor
import
GGMLTensor
from
lightx2v.utils.ggml_tensor
import
GGMLTensor
...
@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
...
@@ -128,7 +130,9 @@ class MMWeight(MMWeightTemplate):
def
_get_source_tensor
(
self
,
source_name
,
weight_dict
=
None
):
def
_get_source_tensor
(
self
,
source_name
,
weight_dict
=
None
):
if
self
.
lazy_load
:
if
self
.
lazy_load
:
return
self
.
lazy_load_file
.
get_tensor
(
source_name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
source_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
return
lazy_load_file
.
get_tensor
(
source_name
)
return
weight_dict
[
source_name
]
return
weight_dict
[
source_name
]
def
_create_pin_tensor
(
self
,
tensor
,
transpose
=
False
):
def
_create_pin_tensor
(
self
,
tensor
,
transpose
=
False
):
...
@@ -145,11 +149,14 @@ class MMWeight(MMWeightTemplate):
...
@@ -145,11 +149,14 @@ class MMWeight(MMWeightTemplate):
self
.
bias_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
bias_name
,
weight_dict
).
to
(
AI_DEVICE
)
self
.
bias_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
bias_name
,
weight_dict
).
to
(
AI_DEVICE
)
def
_load_cpu_pin_buffers
(
self
):
def
_load_cpu_pin_buffers
(
self
):
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
if
self
.
lazy_load
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
...
@@ -197,10 +204,6 @@ class MMWeight(MMWeightTemplate):
else
:
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
assert
adapter_block_index
is
not
None
...
@@ -208,7 +211,14 @@ class MMWeight(MMWeightTemplate):
...
@@ -208,7 +211,14 @@ class MMWeight(MMWeightTemplate):
else
:
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
del
bias_tensor
...
@@ -283,7 +293,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -283,7 +293,13 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self
.
_load_default_tensors
(
weight_dict
)
self
.
_load_default_tensors
(
weight_dict
)
def
_load_cuda_buffers
(
self
,
weight_dict
):
def
_load_cuda_buffers
(
self
,
weight_dict
):
source
=
self
.
lazy_load_file
if
self
.
lazy_load
else
weight_dict
if
self
.
lazy_load
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
else
:
source
=
weight_dict
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
...
@@ -318,14 +334,17 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -318,14 +334,17 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
_get_cpu_pin_tensor_pair
(
self
,
source
,
is_lazy
):
def
_get_cpu_pin_tensor_pair
(
self
,
source
,
is_lazy
):
if
is_lazy
:
if
is_lazy
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
weight_tensor
=
source
.
get_tensor
(
self
.
weight_name
)
weight_tensor
=
source
.
get_tensor
(
self
.
weight_name
)
scale_tensor
=
source
.
get_tensor
(
self
.
weight_scale_name
)
scale_tensor
=
source
.
get_tensor
(
self
.
weight_scale_name
)
scale_dtype
=
torch
.
float
scale_dtype
=
torch
.
float
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
else
:
else
:
weight_tensor
=
source
[
self
.
weight_name
]
weight_tensor
=
source
[
self
.
weight_name
]
scale_tensor
=
source
[
self
.
weight_scale_name
]
scale_tensor
=
source
[
self
.
weight_scale_name
]
scale_dtype
=
torch
.
float
scale_dtype
=
torch
.
float
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
return
pin_weight
,
pin_scale
return
pin_weight
,
pin_scale
...
@@ -334,9 +353,14 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -334,9 +353,14 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
self
.
bias_name
is
None
:
if
self
.
bias_name
is
None
:
return
None
return
None
if
is_lazy
:
if
is_lazy
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
source
:
bias_tensor
=
source
.
get_tensor
(
self
.
bias_name
)
bias_tensor
=
source
.
get_tensor
(
self
.
bias_name
)
if
not
self
.
bias_force_fp32
:
if
not
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
self
.
infer_dtype
)
bias_tensor
=
bias_tensor
.
to
(
self
.
infer_dtype
)
if
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
torch
.
float32
)
return
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
else
:
bias_tensor
=
source
[
self
.
bias_name
]
bias_tensor
=
source
[
self
.
bias_name
]
if
self
.
bias_force_fp32
:
if
self
.
bias_force_fp32
:
...
@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -643,17 +667,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
if
self
.
weight_need_transpose
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
else
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
weight_scale_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
)
self
.
pin_weight_scale
=
self
.
pin_weight_scale
.
copy_
(
weight_scale_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
assert
adapter_block_index
is
not
None
...
@@ -661,7 +674,22 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -661,7 +674,22 @@ class MMWeightQuantTemplate(MMWeightTemplate):
else
:
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
if
self
.
weight_need_transpose
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
else
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
weight_scale_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
)
self
.
pin_weight_scale
=
self
.
pin_weight_scale
.
copy_
(
weight_scale_tensor
)
del
weight_scale_tensor
if
self
.
bias_name
is
not
None
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
del
bias_tensor
...
...
lightx2v/common/ops/norm/layer_norm_weight.py
View file @
f3b4ba24
import
os
import
re
import
re
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch
import
torch
from
safetensors
import
safe_open
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.registry_factory
import
LN_WEIGHT_REGISTER
from
lightx2v.utils.registry_factory
import
LN_WEIGHT_REGISTER
...
@@ -53,7 +55,9 @@ class LNWeightTemplate(metaclass=ABCMeta):
...
@@ -53,7 +55,9 @@ class LNWeightTemplate(metaclass=ABCMeta):
if
name
is
None
:
if
name
is
None
:
return
None
return
None
if
self
.
lazy_load
:
if
self
.
lazy_load
:
tensor
=
self
.
lazy_load_file
.
get_tensor
(
name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
tensor
=
lazy_load_file
.
get_tensor
(
name
)
if
use_infer_dtype
:
if
use_infer_dtype
:
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
else
:
else
:
...
@@ -151,23 +155,27 @@ class LNWeightTemplate(metaclass=ABCMeta):
...
@@ -151,23 +155,27 @@ class LNWeightTemplate(metaclass=ABCMeta):
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
weight_name
is
not
None
:
if
self
.
weight_name
is
not
None
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
if
self
.
is_post_adapter
:
if
self
.
is_post_adapter
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
to
(
self
.
infer_dtype
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
to
(
self
.
infer_dtype
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
if
self
.
is_post_adapter
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
).
to
(
self
.
infer_dtype
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
bias_tensor
=
lazy_load_file
.
get_tensor
(
self
.
bias_name
).
to
(
self
.
infer_dtype
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
del
bias_tensor
...
...
lightx2v/common/ops/norm/rms_norm_weight.py
View file @
f3b4ba24
import
os
import
re
import
re
from
abc
import
ABCMeta
,
abstractmethod
from
abc
import
ABCMeta
,
abstractmethod
import
torch
import
torch
from
safetensors
import
safe_open
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.registry_factory
import
RMS_WEIGHT_REGISTER
from
lightx2v.utils.registry_factory
import
RMS_WEIGHT_REGISTER
...
@@ -46,7 +48,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
...
@@ -46,7 +48,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
def
_get_weight_tensor
(
self
,
weight_dict
=
None
,
use_infer_dtype
=
False
):
def
_get_weight_tensor
(
self
,
weight_dict
=
None
,
use_infer_dtype
=
False
):
if
self
.
lazy_load
:
if
self
.
lazy_load
:
tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
weight_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
if
use_infer_dtype
:
if
use_infer_dtype
:
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
else
:
else
:
...
@@ -107,8 +111,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
...
@@ -107,8 +111,9 @@ class RMSWeightTemplate(metaclass=ABCMeta):
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
to
(
self
.
infer_dtype
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
weight_tensor
=
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
to
(
self
.
infer_dtype
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
del
weight_tensor
...
...
lightx2v/common/ops/tensor/tensor.py
View file @
f3b4ba24
import
os
import
re
import
re
import
torch
import
torch
from
safetensors
import
safe_open
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.envs
import
*
from
lightx2v.utils.registry_factory
import
TENSOR_REGISTER
from
lightx2v.utils.registry_factory
import
TENSOR_REGISTER
...
@@ -39,7 +41,9 @@ class DefaultTensor:
...
@@ -39,7 +41,9 @@ class DefaultTensor:
def
_get_tensor
(
self
,
weight_dict
=
None
,
use_infer_dtype
=
False
):
def
_get_tensor
(
self
,
weight_dict
=
None
,
use_infer_dtype
=
False
):
if
self
.
lazy_load
:
if
self
.
lazy_load
:
tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
tensor_name
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
self
.
tensor_name
.
split
(
'.'
)[
1
]
}
.safetensors"
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
tensor
=
lazy_load_file
.
get_tensor
(
self
.
tensor_name
)
if
use_infer_dtype
:
if
use_infer_dtype
:
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
tensor
=
tensor
.
to
(
self
.
infer_dtype
)
else
:
else
:
...
@@ -92,7 +96,8 @@ class DefaultTensor:
...
@@ -92,7 +96,8 @@ class DefaultTensor:
self
.
tensor_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
tensor_name
,
count
=
1
)
self
.
tensor_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
tensor_name
,
count
=
1
)
else
:
else
:
self
.
tensor_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
tensor_name
,
count
=
1
)
self
.
tensor_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
tensor_name
,
count
=
1
)
lazy_load_file_path
=
os
.
path
.
join
(
self
.
lazy_load_file
,
f
"block_
{
block_index
}
.safetensors"
)
tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
tensor_name
).
to
(
self
.
infer_dtype
)
with
safe_open
(
lazy_load_file_path
,
framework
=
"pt"
,
device
=
"cpu"
)
as
lazy_load_file
:
tensor
=
lazy_load_file
.
get_tensor
(
self
.
tensor_name
).
to
(
self
.
infer_dtype
)
self
.
pin_tensor
=
self
.
pin_tensor
.
copy_
(
tensor
)
self
.
pin_tensor
=
self
.
pin_tensor
.
copy_
(
tensor
)
del
tensor
del
tensor
lightx2v/models/networks/hunyuan_video/infer/attn_no_pad.py
View file @
f3b4ba24
import
torch
import
torch
from
einops
import
rearrange
from
einops
import
rearrange
from
flash_attn
import
flash_attn_varlen_qkvpacked_func
from
flash_attn.bert_padding
import
pad_input
,
unpad_input
from
loguru
import
logger
from
loguru
import
logger
try
:
from
flash_attn
import
flash_attn_varlen_qkvpacked_func
except
ImportError
:
flash_attn_varlen_qkvpacked_func
=
None
logger
.
info
(
"flash_attn_varlen_qkvpacked_func not available"
)
try
:
from
flash_attn.bert_padding
import
pad_input
,
unpad_input
except
ImportError
:
pad_input
=
None
unpad_input
=
None
logger
.
info
(
"flash_attn.bert_padding not available"
)
try
:
try
:
from
flash_attn_interface
import
flash_attn_varlen_func
as
flash_attn_varlen_func_v3
from
flash_attn_interface
import
flash_attn_varlen_func
as
flash_attn_varlen_func_v3
except
ImportError
:
except
ImportError
:
...
...
lightx2v/models/networks/wan/infer/offload/transformer_infer.py
View file @
f3b4ba24
...
@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
...
@@ -32,6 +32,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
if
offload_granularity
!=
"model"
:
if
offload_granularity
!=
"model"
:
self
.
offload_manager
=
WeightAsyncStreamManager
(
offload_granularity
=
offload_granularity
)
self
.
offload_manager
=
WeightAsyncStreamManager
(
offload_granularity
=
offload_granularity
)
self
.
lazy_load
=
self
.
config
.
get
(
"lazy_load"
,
False
)
if
self
.
lazy_load
and
offload_granularity
==
"phase"
:
self
.
offload_manager
.
init_lazy_load
(
num_workers
=
self
.
config
.
get
(
"num_disk_workers"
,
4
))
def
infer_with_blocks_offload
(
self
,
blocks
,
x
,
pre_infer_out
):
def
infer_with_blocks_offload
(
self
,
blocks
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
len
(
blocks
)):
for
block_idx
in
range
(
len
(
blocks
)):
...
@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
...
@@ -57,6 +60,10 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
def
infer_with_phases_offload
(
self
,
blocks
,
x
,
pre_infer_out
):
def
infer_with_phases_offload
(
self
,
blocks
,
x
,
pre_infer_out
):
for
block_idx
in
range
(
len
(
blocks
)):
for
block_idx
in
range
(
len
(
blocks
)):
self
.
block_idx
=
block_idx
self
.
block_idx
=
block_idx
if
self
.
lazy_load
:
next_prefetch
=
(
block_idx
+
1
)
%
len
(
blocks
)
self
.
offload_manager
.
start_prefetch_block
(
next_prefetch
)
x
=
self
.
infer_phases
(
block_idx
,
blocks
,
x
,
pre_infer_out
)
x
=
self
.
infer_phases
(
block_idx
,
blocks
,
x
,
pre_infer_out
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
(
del
(
...
@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
...
@@ -77,6 +84,9 @@ class WanOffloadTransformerInfer(WanTransformerInfer):
self
.
offload_manager
.
init_first_buffer
(
blocks
)
self
.
offload_manager
.
init_first_buffer
(
blocks
)
next_block_idx
=
(
block_idx
+
1
)
%
len
(
blocks
)
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_block_idx
=
(
block_idx
+
1
)
%
len
(
blocks
)
if
phase_idx
==
self
.
phases_num
-
1
else
block_idx
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
phases_num
next_phase_idx
=
(
phase_idx
+
1
)
%
self
.
phases_num
if
self
.
lazy_load
:
if
phase_idx
==
self
.
phases_num
-
1
:
self
.
offload_manager
.
swap_cpu_buffers
()
self
.
offload_manager
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
blocks
)
self
.
offload_manager
.
prefetch_phase
(
next_block_idx
,
next_phase_idx
,
blocks
)
with
torch_device_module
.
stream
(
self
.
offload_manager
.
compute_stream
):
with
torch_device_module
.
stream
(
self
.
offload_manager
.
compute_stream
):
x
=
self
.
infer_phase
(
phase_idx
,
self
.
offload_manager
.
cuda_buffers
[
phase_idx
],
x
,
pre_infer_out
)
x
=
self
.
infer_phase
(
phase_idx
,
self
.
offload_manager
.
cuda_buffers
[
phase_idx
],
x
,
pre_infer_out
)
...
...
lightx2v/models/networks/wan/infer/transformer_infer.py
View file @
f3b4ba24
...
@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -171,7 +171,7 @@ class WanTransformerInfer(BaseTransformerInfer):
cu_seqlens_qkv
=
torch
.
tensor
([
0
,
img_qkv_len
],
dtype
=
torch
.
int32
,
device
=
"cpu"
).
to
(
q
.
device
,
non_blocking
=
True
)
cu_seqlens_qkv
=
torch
.
tensor
([
0
,
img_qkv_len
],
dtype
=
torch
.
int32
,
device
=
"cpu"
).
to
(
q
.
device
,
non_blocking
=
True
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
norm1_out
,
norm1_weight
,
norm1_bias
del
norm1_out
,
shift_msa
,
scale_msa
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
if
self
.
config
[
"seq_parallel"
]:
if
self
.
config
[
"seq_parallel"
]:
...
@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
...
@@ -300,7 +300,7 @@ class WanTransformerInfer(BaseTransformerInfer):
y
=
phase
.
ffn_0
.
apply
(
norm2_out
)
y
=
phase
.
ffn_0
.
apply
(
norm2_out
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
del
norm2_out
,
x
,
norm2_weight
,
norm2_bias
del
norm2_out
,
x
torch
.
cuda
.
empty_cache
()
torch
.
cuda
.
empty_cache
()
y
=
torch
.
nn
.
functional
.
gelu
(
y
,
approximate
=
"tanh"
)
y
=
torch
.
nn
.
functional
.
gelu
(
y
,
approximate
=
"tanh"
)
if
self
.
clean_cuda_cache
:
if
self
.
clean_cuda_cache
:
...
...
lightx2v/models/networks/wan/infer/utils.py
View file @
f3b4ba24
...
@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
...
@@ -36,29 +36,26 @@ def apply_wan_rope_with_chunk(
rope_func
,
rope_func
,
):
):
seq_len
=
cos_sin_cache
.
size
(
0
)
seq_len
=
cos_sin_cache
.
size
(
0
)
x_q
=
torch
.
empty_like
(
xq
)
x_k
=
torch
.
empty_like
(
xk
)
xq_output_chunks
=
[]
xk_output_chunks
=
[]
for
start
in
range
(
0
,
seq_len
,
chunk_size
):
for
start
in
range
(
0
,
seq_len
,
chunk_size
):
end
=
min
(
start
+
chunk_size
,
seq_len
)
end
=
min
(
start
+
chunk_size
,
seq_len
)
xq_chunk
=
xq
[
start
:
end
]
xq_chunk
=
xq
[
start
:
end
]
xk_chunk
=
xk
[
start
:
end
]
xk_chunk
=
xk
[
start
:
end
]
cos_sin_chunk
=
cos_sin_cache
[
start
:
end
]
cos_sin_chunk
=
cos_sin_cache
[
start
:
end
]
xq_chunk_out
,
xk_chunk_out
=
rope_func
(
xq_chunk
,
xk_chunk
,
cos_sin_chunk
)
xq_chunk
,
xk_chunk
=
rope_func
(
xq_chunk
,
xk_chunk
,
cos_sin_chunk
)
x_q
[
start
:
end
].
copy_
(
xq_chunk_out
,
non_blocking
=
True
)
xq_output_chunks
.
append
(
xq_chunk
)
x_k
[
start
:
end
].
copy_
(
xk_chunk_out
,
non_blocking
=
True
)
xk_output_chunks
.
append
(
xk_chunk
)
del
xq_chunk_out
,
xk_chunk_out
torch
.
cuda
.
empty_cache
()
target_dtype
=
GET_DTYPE
()
x_q
=
torch
.
cat
(
xq_output_chunks
,
dim
=
0
)
if
x_q
.
dtype
!=
target_dtype
:
del
xq_output_chunks
x_q
=
x_q
.
to
(
target_dtype
)
torch
.
cuda
.
empty_cache
()
if
x_k
.
dtype
!=
target_dtype
:
x_k
=
x_k
.
to
(
target_dtype
)
x_k
=
torch
.
cat
(
xk_output_chunks
,
dim
=
0
)
del
xk_output_chunks
return
x_q
,
x_k
torch
.
cuda
.
empty_cache
()
return
x_q
.
to
(
GET_DTYPE
()),
x_k
.
to
(
GET_DTYPE
())
def
apply_wan_rope_with_flashinfer
(
def
apply_wan_rope_with_flashinfer
(
...
...
lightx2v/models/networks/wan/model.py
View file @
f3b4ba24
...
@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
...
@@ -173,8 +173,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_files
=
[
safetensors_path
]
safetensors_files
=
[
safetensors_path
]
if
self
.
lazy_load
:
if
self
.
lazy_load
:
assert
len
(
safetensors_files
)
==
1
,
"Only support single safetensors file in lazy load mode"
self
.
lazy_load_path
=
safetensors_path
self
.
lazy_load_path
=
safetensors_files
[
0
]
non_block_file
=
os
.
path
.
join
(
safetensors_path
,
"non_block.safetensors"
)
if
os
.
path
.
exists
(
non_block_file
):
safetensors_files
=
[
non_block_file
]
else
:
raise
ValueError
(
f
"Non-block file not found in
{
safetensors_path
}
"
)
weight_dict
=
{}
weight_dict
=
{}
for
file_path
in
safetensors_files
:
for
file_path
in
safetensors_files
:
...
@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
...
@@ -189,7 +193,6 @@ class WanModel(CompiledMethodsMixin):
def
_load_quant_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
def
_load_quant_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
remove_keys
=
self
.
remove_keys
if
hasattr
(
self
,
"remove_keys"
)
else
[]
remove_keys
=
self
.
remove_keys
if
hasattr
(
self
,
"remove_keys"
)
else
[]
if
self
.
config
.
get
(
"dit_quantized_ckpt"
,
None
):
if
self
.
config
.
get
(
"dit_quantized_ckpt"
,
None
):
safetensors_path
=
self
.
config
[
"dit_quantized_ckpt"
]
safetensors_path
=
self
.
config
[
"dit_quantized_ckpt"
]
else
:
else
:
...
@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
...
@@ -213,8 +216,12 @@ class WanModel(CompiledMethodsMixin):
safetensors_path
=
os
.
path
.
dirname
(
safetensors_path
)
safetensors_path
=
os
.
path
.
dirname
(
safetensors_path
)
if
self
.
lazy_load
:
if
self
.
lazy_load
:
assert
len
(
safetensors_files
)
==
1
,
"Only support single safetensors file in lazy load mode"
self
.
lazy_load_path
=
safetensors_path
self
.
lazy_load_path
=
safetensors_files
[
0
]
non_block_file
=
os
.
path
.
join
(
safetensors_path
,
"non_block.safetensors"
)
if
os
.
path
.
exists
(
non_block_file
):
safetensors_files
=
[
non_block_file
]
else
:
raise
ValueError
(
f
"Non-block file not found in
{
safetensors_path
}
, Please check the lazy load model path"
)
weight_dict
=
{}
weight_dict
=
{}
for
safetensor_path
in
safetensors_files
:
for
safetensor_path
in
safetensors_files
:
...
@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
...
@@ -372,9 +379,14 @@ class WanModel(CompiledMethodsMixin):
self
.
post_infer
=
self
.
post_infer_class
(
self
.
config
)
self
.
post_infer
=
self
.
post_infer_class
(
self
.
config
)
self
.
transformer_infer
=
self
.
transformer_infer_class
(
self
.
config
)
self
.
transformer_infer
=
self
.
transformer_infer_class
(
self
.
config
)
if
hasattr
(
self
.
transformer_infer
,
"offload_manager"
):
if
hasattr
(
self
.
transformer_infer
,
"offload_manager"
):
self
.
_init_offload_manager
()
def _init_offload_manager(self):
    """Wire the transformer weights' preallocated offload buffers into the offload manager.

    CUDA-side block/phase buffers are always registered; in lazy-load mode the
    CPU-side buffers are registered as well, with an optional warm-up pass over
    all transformer blocks when ``warm_up_cpu_buffers`` is enabled in config.
    """
    # CUDA buffers are required for every offload configuration.
    self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)
    if self.lazy_load:
        # Lazy load additionally stages weights through CPU-side buffers.
        self.transformer_infer.offload_manager.init_cpu_buffer(self.transformer_weights.offload_block_cpu_buffers, self.transformer_weights.offload_phase_cpu_buffers)
        if self.config.get("warm_up_cpu_buffers", False):
            # Optional: pre-fill the CPU buffers for all blocks up front.
            self.transformer_infer.offload_manager.warm_up_cpu_buffers(self.transformer_weights.blocks_num)
def
set_scheduler
(
self
,
scheduler
):
def
set_scheduler
(
self
,
scheduler
):
self
.
scheduler
=
scheduler
self
.
scheduler
=
scheduler
...
...
lightx2v/models/networks/wan/weights/transformer_weights.py
View file @
f3b4ba24
from
safetensors
import
safe_open
from
lightx2v.common.modules.weight_module
import
WeightModule
,
WeightModuleList
from
lightx2v.common.modules.weight_module
import
WeightModule
,
WeightModuleList
from
lightx2v.utils.registry_factory
import
(
from
lightx2v.utils.registry_factory
import
(
ATTN_WEIGHT_REGISTER
,
ATTN_WEIGHT_REGISTER
,
...
@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
...
@@ -22,10 +20,6 @@ class WanTransformerWeights(WeightModule):
if
config
.
get
(
"do_mm_calib"
,
False
):
if
config
.
get
(
"do_mm_calib"
,
False
):
self
.
mm_type
=
"Calib"
self
.
mm_type
=
"Calib"
self
.
lazy_load
=
self
.
config
.
get
(
"lazy_load"
,
False
)
self
.
lazy_load
=
self
.
config
.
get
(
"lazy_load"
,
False
)
if
not
self
.
lazy_load
:
self
.
lazy_load_file
=
None
else
:
self
.
lazy_load_file
=
safe_open
(
lazy_load_path
,
framework
=
"pt"
,
device
=
"cpu"
)
self
.
blocks
=
WeightModuleList
(
self
.
blocks
=
WeightModuleList
(
[
[
WanTransformerAttentionBlock
(
WanTransformerAttentionBlock
(
...
@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
...
@@ -37,12 +31,12 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer
=
False
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load
=
self
.
lazy_load
,
lazy_load_
file
=
self
.
lazy_load_
file
,
lazy_load_
path
=
lazy_load_
path
,
)
)
for
i
in
range
(
self
.
blocks_num
)
for
i
in
range
(
self
.
blocks_num
)
]
]
)
)
self
.
register_offload_buffers
(
config
)
self
.
register_offload_buffers
(
config
,
lazy_load_path
)
self
.
add_module
(
"blocks"
,
self
.
blocks
)
self
.
add_module
(
"blocks"
,
self
.
blocks
)
# non blocks weights
# non blocks weights
...
@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
...
@@ -50,7 +44,7 @@ class WanTransformerWeights(WeightModule):
self
.
add_module
(
"head"
,
MM_WEIGHT_REGISTER
[
"Default"
](
"head.head.weight"
,
"head.head.bias"
))
self
.
add_module
(
"head"
,
MM_WEIGHT_REGISTER
[
"Default"
](
"head.head.weight"
,
"head.head.bias"
))
self
.
register_parameter
(
"head_modulation"
,
TENSOR_REGISTER
[
"Default"
](
"head.modulation"
))
self
.
register_parameter
(
"head_modulation"
,
TENSOR_REGISTER
[
"Default"
](
"head.modulation"
))
def
register_offload_buffers
(
self
,
config
):
def
register_offload_buffers
(
self
,
config
,
lazy_load_path
):
if
config
[
"cpu_offload"
]:
if
config
[
"cpu_offload"
]:
if
config
[
"offload_granularity"
]
==
"block"
:
if
config
[
"offload_granularity"
]
==
"block"
:
self
.
offload_blocks_num
=
2
self
.
offload_blocks_num
=
2
...
@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
...
@@ -65,7 +59,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer
=
False
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load
=
self
.
lazy_load
,
lazy_load_
file
=
self
.
lazy_load_
file
,
lazy_load_
path
=
lazy_load_
path
,
)
)
for
i
in
range
(
self
.
offload_blocks_num
)
for
i
in
range
(
self
.
offload_blocks_num
)
]
]
...
@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
...
@@ -86,7 +80,7 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer
=
True
,
create_cpu_buffer
=
True
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load
=
self
.
lazy_load
,
lazy_load_
file
=
self
.
lazy_load_
file
,
lazy_load_
path
=
lazy_load_
path
,
)
)
for
i
in
range
(
self
.
offload_blocks_num
)
for
i
in
range
(
self
.
offload_blocks_num
)
]
]
...
@@ -104,13 +98,15 @@ class WanTransformerWeights(WeightModule):
...
@@ -104,13 +98,15 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer
=
False
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load
=
self
.
lazy_load
,
lazy_load_
file
=
self
.
lazy_load_
file
,
lazy_load_
path
=
lazy_load_
path
,
).
compute_phases
).
compute_phases
self
.
add_module
(
"offload_phase_cuda_buffers"
,
self
.
offload_phase_cuda_buffers
)
self
.
add_module
(
"offload_phase_cuda_buffers"
,
self
.
offload_phase_cuda_buffers
)
self
.
offload_block_cuda_buffers
=
None
self
.
offload_block_cuda_buffers
=
None
if
self
.
lazy_load
:
if
self
.
lazy_load
:
self
.
offload_phase_cpu_buffers
=
WanTransformerAttentionBlock
(
self
.
offload_phase_cpu_buffers
=
WeightModuleList
(
block_index
=
0
,
[
WanTransformerAttentionBlock
(
block_index
=
i
,
task
=
self
.
task
,
task
=
self
.
task
,
mm_type
=
self
.
mm_type
,
mm_type
=
self
.
mm_type
,
config
=
self
.
config
,
config
=
self
.
config
,
...
@@ -118,8 +114,11 @@ class WanTransformerWeights(WeightModule):
...
@@ -118,8 +114,11 @@ class WanTransformerWeights(WeightModule):
create_cpu_buffer
=
True
,
create_cpu_buffer
=
True
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
self
.
lazy_load
,
lazy_load
=
self
.
lazy_load
,
lazy_load_
file
=
self
.
lazy_load_
file
,
lazy_load_
path
=
lazy_load_
path
,
).
compute_phases
).
compute_phases
for
i
in
range
(
2
)
]
)
self
.
add_module
(
"offload_phase_cpu_buffers"
,
self
.
offload_phase_cpu_buffers
)
self
.
add_module
(
"offload_phase_cpu_buffers"
,
self
.
offload_phase_cpu_buffers
)
self
.
offload_block_cpu_buffers
=
None
self
.
offload_block_cpu_buffers
=
None
...
@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
...
@@ -145,7 +144,7 @@ class WanTransformerAttentionBlock(WeightModule):
create_cpu_buffer
=
False
,
create_cpu_buffer
=
False
,
block_prefix
=
"blocks"
,
block_prefix
=
"blocks"
,
lazy_load
=
False
,
lazy_load
=
False
,
lazy_load_
file
=
None
,
lazy_load_
path
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
self
.
block_index
=
block_index
self
.
block_index
=
block_index
...
@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
...
@@ -157,7 +156,10 @@ class WanTransformerAttentionBlock(WeightModule):
self
.
quant_method
=
config
.
get
(
"quant_method"
,
None
)
self
.
quant_method
=
config
.
get
(
"quant_method"
,
None
)
self
.
lazy_load
=
lazy_load
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
if
self
.
lazy_load
:
self
.
lazy_load_file
=
lazy_load_path
else
:
self
.
lazy_load_file
=
None
self
.
compute_phases
=
WeightModuleList
(
self
.
compute_phases
=
WeightModuleList
(
[
[
...
...
lightx2v/models/runners/default_runner.py
View file @
f3b4ba24
...
@@ -185,6 +185,18 @@ class DefaultRunner(BaseRunner):
...
@@ -185,6 +185,18 @@ class DefaultRunner(BaseRunner):
del
self
.
inputs
del
self
.
inputs
self
.
input_info
=
None
self
.
input_info
=
None
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
self
.
config
.
get
(
"lazy_load"
,
False
)
or
self
.
config
.
get
(
"unload_modules"
,
False
):
if
hasattr
(
self
.
model
,
"model"
)
and
len
(
self
.
model
.
model
)
==
2
:
# MultiModelStruct
for
model
in
self
.
model
.
model
:
if
hasattr
(
model
.
transformer_infer
,
"offload_manager"
):
del
model
.
transformer_infer
.
offload_manager
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
del
model
else
:
if
hasattr
(
self
.
model
.
transformer_infer
,
"offload_manager"
):
del
self
.
model
.
transformer_infer
.
offload_manager
torch
.
cuda
.
empty_cache
()
gc
.
collect
()
del
self
.
model
del
self
.
model
if
self
.
config
.
get
(
"do_mm_calib"
,
False
):
if
self
.
config
.
get
(
"do_mm_calib"
,
False
):
calib_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"calib.pt"
)
calib_path
=
os
.
path
.
join
(
os
.
getcwd
(),
"calib.pt"
)
...
...
lightx2v/models/runners/wan/wan_distill_runner.py
View file @
f3b4ba24
...
@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
...
@@ -73,6 +73,35 @@ class MultiDistillModelStruct(MultiModelStruct):
self
.
to_cuda
(
model_index
=
1
)
self
.
to_cuda
(
model_index
=
1
)
self
.
cur_model_index
=
1
self
.
cur_model_index
=
1
def infer(self, inputs):
    """Run inference with the sub-model selected for the current step, building it on demand.

    When neither ``lazy_load`` nor ``unload_modules`` is set, both sub-models are
    assumed resident and the selected one is invoked directly.  Otherwise a
    missing sub-model (slot is ``None``) is constructed from its saved path,
    attached to the shared scheduler, and cached in ``self.model`` first.
    """
    # Updates self.cur_model_index based on the current scheduler state.
    self.get_current_model_index()
    if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
        self.model[self.cur_model_index].infer(inputs)
    else:
        if self.model[self.cur_model_index] is not None:
            self.model[self.cur_model_index].infer(inputs)
        else:
            if self.cur_model_index == 0:
                # Slot 0 holds the high-noise distilled model.
                high_noise_model = WanDistillModel(
                    self.high_noise_model_path,
                    self.config,
                    self.init_device,
                    model_type="wan2.2_moe_high_noise",
                )
                high_noise_model.set_scheduler(self.scheduler)
                self.model[0] = high_noise_model
                self.model[0].infer(inputs)
            elif self.cur_model_index == 1:
                # Slot 1 holds the low-noise distilled model.
                low_noise_model = WanDistillModel(
                    self.low_noise_model_path,
                    self.config,
                    self.init_device,
                    model_type="wan2.2_moe_low_noise",
                )
                low_noise_model.set_scheduler(self.scheduler)
                self.model[1] = low_noise_model
                self.model[1].infer(inputs)
@
RUNNER_REGISTER
(
"wan2.2_moe_distill"
)
@
RUNNER_REGISTER
(
"wan2.2_moe_distill"
)
class
Wan22MoeDistillRunner
(
WanDistillRunner
):
class
Wan22MoeDistillRunner
(
WanDistillRunner
):
...
@@ -101,6 +130,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
...
@@ -101,6 +130,7 @@ class Wan22MoeDistillRunner(WanDistillRunner):
raise
FileNotFoundError
(
f
"Low Noise Model does not find"
)
raise
FileNotFoundError
(
f
"Low Noise Model does not find"
)
def
load_transformer
(
self
):
def
load_transformer
(
self
):
if
not
self
.
config
.
get
(
"lazy_load"
,
False
)
and
not
self
.
config
.
get
(
"unload_modules"
,
False
):
use_high_lora
,
use_low_lora
=
False
,
False
use_high_lora
,
use_low_lora
=
False
,
False
if
self
.
config
.
get
(
"lora_configs"
)
and
self
.
config
[
"lora_configs"
]:
if
self
.
config
.
get
(
"lora_configs"
)
and
self
.
config
[
"lora_configs"
]:
for
lora_config
in
self
.
config
[
"lora_configs"
]:
for
lora_config
in
self
.
config
[
"lora_configs"
]:
...
@@ -156,6 +186,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
...
@@ -156,6 +186,12 @@ class Wan22MoeDistillRunner(WanDistillRunner):
)
)
return
MultiDistillModelStruct
([
high_noise_model
,
low_noise_model
],
self
.
config
,
self
.
config
[
"boundary_step_index"
])
return
MultiDistillModelStruct
([
high_noise_model
,
low_noise_model
],
self
.
config
,
self
.
config
[
"boundary_step_index"
])
else
:
model_struct
=
MultiDistillModelStruct
([
None
,
None
],
self
.
config
,
self
.
config
[
"boundary_step_index"
])
model_struct
.
low_noise_model_path
=
self
.
low_noise_model_path
model_struct
.
high_noise_model_path
=
self
.
high_noise_model_path
model_struct
.
init_device
=
self
.
init_device
return
model_struct
def
init_scheduler
(
self
):
def
init_scheduler
(
self
):
if
self
.
config
[
"feature_caching"
]
==
"NoCaching"
:
if
self
.
config
[
"feature_caching"
]
==
"NoCaching"
:
...
...
lightx2v/models/runners/wan/wan_runner.py
View file @
f3b4ba24
...
@@ -468,11 +468,37 @@ class MultiModelStruct:
...
@@ -468,11 +468,37 @@ class MultiModelStruct:
def set_scheduler(self, shared_scheduler):
    """Share one scheduler instance across this struct and all loaded sub-models."""
    self.scheduler = shared_scheduler
    for model in self.model:
        # In lazy-load mode a slot may still be None until first use;
        # such models receive the scheduler when they are constructed later.
        if model is not None:
            model.set_scheduler(shared_scheduler)
def infer(self, inputs):
    """Run inference with the sub-model selected for the current step, building it on demand.

    When neither ``lazy_load`` nor ``unload_modules`` is set, both sub-models are
    assumed resident and the selected one is invoked directly.  Otherwise a
    missing sub-model (slot is ``None``) is constructed from its saved path,
    attached to the shared scheduler, and cached in ``self.model`` first.
    """
    # Updates self.cur_model_index based on the current scheduler state.
    self.get_current_model_index()
    if not self.config.get("lazy_load", False) and not self.config.get("unload_modules", False):
        self.model[self.cur_model_index].infer(inputs)
    else:
        if self.model[self.cur_model_index] is not None:
            self.model[self.cur_model_index].infer(inputs)
        else:
            if self.cur_model_index == 0:
                # Slot 0 holds the high-noise model.
                high_noise_model = WanModel(
                    self.high_noise_model_path,
                    self.config,
                    self.init_device,
                    model_type="wan2.2_moe_high_noise",
                )
                high_noise_model.set_scheduler(self.scheduler)
                self.model[0] = high_noise_model
                self.model[0].infer(inputs)
            elif self.cur_model_index == 1:
                # Slot 1 holds the low-noise model.
                low_noise_model = WanModel(
                    self.low_noise_model_path,
                    self.config,
                    self.init_device,
                    model_type="wan2.2_moe_low_noise",
                )
                low_noise_model.set_scheduler(self.scheduler)
                self.model[1] = low_noise_model
                self.model[1].infer(inputs)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
@
ProfilingContext4DebugL2
(
"Swtich models in infer_main costs"
)
def
get_current_model_index
(
self
):
def
get_current_model_index
(
self
):
...
@@ -526,6 +552,7 @@ class Wan22MoeRunner(WanRunner):
...
@@ -526,6 +552,7 @@ class Wan22MoeRunner(WanRunner):
def
load_transformer
(
self
):
def
load_transformer
(
self
):
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
# encoder -> high_noise_model -> low_noise_model -> vae -> video_output
if
not
self
.
config
.
get
(
"lazy_load"
,
False
)
and
not
self
.
config
.
get
(
"unload_modules"
,
False
):
high_noise_model
=
WanModel
(
high_noise_model
=
WanModel
(
self
.
high_noise_model_path
,
self
.
high_noise_model_path
,
self
.
config
,
self
.
config
,
...
@@ -560,6 +587,12 @@ class Wan22MoeRunner(WanRunner):
...
@@ -560,6 +587,12 @@ class Wan22MoeRunner(WanRunner):
raise
ValueError
(
f
"Unsupported LoRA path:
{
lora_path
}
"
)
raise
ValueError
(
f
"Unsupported LoRA path:
{
lora_path
}
"
)
return
MultiModelStruct
([
high_noise_model
,
low_noise_model
],
self
.
config
,
self
.
config
[
"boundary"
])
return
MultiModelStruct
([
high_noise_model
,
low_noise_model
],
self
.
config
,
self
.
config
[
"boundary"
])
else
:
model_struct
=
MultiModelStruct
([
None
,
None
],
self
.
config
,
self
.
config
[
"boundary"
])
model_struct
.
low_noise_model_path
=
self
.
low_noise_model_path
model_struct
.
high_noise_model_path
=
self
.
high_noise_model_path
model_struct
.
init_device
=
self
.
init_device
return
model_struct
@
RUNNER_REGISTER
(
"wan2.2"
)
@
RUNNER_REGISTER
(
"wan2.2"
)
...
...
lightx2v/utils/profiler.py
View file @
f3b4ba24
import
asyncio
import
asyncio
import
threading
import
time
import
time
from
functools
import
wraps
from
functools
import
wraps
...
@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
...
@@ -10,6 +11,13 @@ from lightx2v.utils.envs import *
from
lightx2v_platform.base.global_var
import
AI_DEVICE
from
lightx2v_platform.base.global_var
import
AI_DEVICE
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
torch_device_module
=
getattr
(
torch
,
AI_DEVICE
)
_excluded_time_local
=
threading
.
local
()
def
_get_excluded_time_stack
():
if
not
hasattr
(
_excluded_time_local
,
"stack"
):
_excluded_time_local
.
stack
=
[]
return
_excluded_time_local
.
stack
class
_ProfilingContext
:
class
_ProfilingContext
:
...
@@ -32,11 +40,14 @@ class _ProfilingContext:
...
@@ -32,11 +40,14 @@ class _ProfilingContext:
def
__enter__
(
self
):
def
__enter__
(
self
):
torch_device_module
.
synchronize
()
torch_device_module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
_get_excluded_time_stack
().
append
(
0.0
)
return
self
return
self
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
def
__exit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
torch_device_module
.
synchronize
()
torch_device_module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
total_elapsed
=
time
.
perf_counter
()
-
self
.
start_time
excluded
=
_get_excluded_time_stack
().
pop
()
elapsed
=
total_elapsed
-
excluded
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
self
.
metrics_func
.
labels
(
*
self
.
metrics_labels
).
observe
(
elapsed
)
self
.
metrics_func
.
labels
(
*
self
.
metrics_labels
).
observe
(
elapsed
)
...
@@ -49,11 +60,14 @@ class _ProfilingContext:
...
@@ -49,11 +60,14 @@ class _ProfilingContext:
async
def
__aenter__
(
self
):
async
def
__aenter__
(
self
):
torch_device_module
.
synchronize
()
torch_device_module
.
synchronize
()
self
.
start_time
=
time
.
perf_counter
()
self
.
start_time
=
time
.
perf_counter
()
_get_excluded_time_stack
().
append
(
0.0
)
return
self
return
self
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
async
def
__aexit__
(
self
,
exc_type
,
exc_val
,
exc_tb
):
torch_device_module
.
synchronize
()
torch_device_module
.
synchronize
()
elapsed
=
time
.
perf_counter
()
-
self
.
start_time
total_elapsed
=
time
.
perf_counter
()
-
self
.
start_time
excluded
=
_get_excluded_time_stack
().
pop
()
elapsed
=
total_elapsed
-
excluded
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
enable_recorder
and
self
.
metrics_func
:
if
self
.
metrics_labels
:
if
self
.
metrics_labels
:
self
.
metrics_func
.
labels
(
*
self
.
metrics_labels
).
observe
(
elapsed
)
self
.
metrics_func
.
labels
(
*
self
.
metrics_labels
).
observe
(
elapsed
)
...
@@ -103,6 +117,65 @@ class _NullContext:
...
@@ -103,6 +117,65 @@ class _NullContext:
return
func
return
func
class _ExcludedProfilingContext:
    """Mark a time span that should be excluded from enclosing profiling contexts.

    Works as a synchronous or asynchronous context manager, and as a decorator
    for sync or async callables via ``__call__``.  On exit, the measured
    wall-clock time is added to every entry on the thread-local excluded-time
    stack, so each active outer profiling context can subtract it from its own
    elapsed time.
    """

    def __init__(self, name=None):
        # Optional label used in the log line emitted on exit.
        self.name = name
        if dist.is_initialized():
            self.rank_info = f"Rank {dist.get_rank()}"
        else:
            self.rank_info = "Single GPU"

    def __enter__(self):
        # Synchronize the device so start_time does not include pending kernels.
        torch_device_module.synchronize()
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch_device_module.synchronize()
        elapsed = time.perf_counter() - self.start_time
        # Credit this span to every active outer profiling context on this thread.
        stack = _get_excluded_time_stack()
        for i in range(len(stack)):
            stack[i] += elapsed
        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
        # Do not suppress exceptions raised inside the block.
        return False

    async def __aenter__(self):
        # Async variant of __enter__; measurement logic is identical.
        torch_device_module.synchronize()
        self.start_time = time.perf_counter()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        # Async variant of __exit__; measurement logic is identical.
        torch_device_module.synchronize()
        elapsed = time.perf_counter() - self.start_time
        stack = _get_excluded_time_stack()
        for i in range(len(stack)):
            stack[i] += elapsed
        if self.name and CHECK_PROFILING_DEBUG_LEVEL(1):
            logger.info(f"[Profile-Excluded] {self.rank_info} - {self.name} cost {elapsed:.6f} seconds (excluded from outer profiling)")
        return False

    def __call__(self, func):
        # Decorator support: pick the wrapper matching the callable's kind.
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def async_wrapper(*args, **kwargs):
                async with self:
                    return await func(*args, **kwargs)

            return async_wrapper
        else:

            @wraps(func)
            def sync_wrapper(*args, **kwargs):
                with self:
                    return func(*args, **kwargs)

            return sync_wrapper
class
_ProfilingContextL1
(
_ProfilingContext
):
class
_ProfilingContextL1
(
_ProfilingContext
):
"""Level 1 profiling context with Level1_Log prefix."""
"""Level 1 profiling context with Level1_Log prefix."""
...
@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
...
@@ -124,3 +197,4 @@ PROFILING_DEBUG_LEVEL=2: enable ProfilingContext4DebugL1 and ProfilingContext4De
"""
"""
# Public aliases: real profiling contexts when the debug level is high enough,
# otherwise a no-op context so call sites need no conditionals.
ProfilingContext4DebugL1 = _ProfilingContextL1 if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # if user >= 1, enable profiling
ProfilingContext4DebugL2 = _ProfilingContextL2 if CHECK_PROFILING_DEBUG_LEVEL(2) else _NullContext  # if user >= 2, enable profiling
ExcludedProfilingContext = _ExcludedProfilingContext if CHECK_PROFILING_DEBUG_LEVEL(1) else _NullContext  # exclusion accounting only matters when level-1 profiling is active
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment