xuwx1 / LightX2V · Commits

Commit 74eeb429 (unverified)
reconstruct disk offload and fix lightx2v_platform bugs (#558)

Authored Dec 03, 2025 by Gu Shiqiao; committed via GitHub on Dec 03, 2025
Co-authored-by: helloyongyang <yongyang1030@163.com>
Parent: f7cdbcb5

Changes: 46 files in the commit; showing 20 changed files with 671 additions and 826 deletions (+671 -826).
Changed files shown on this page:

lightx2v/common/modules/weight_module.py  +8 -29
lightx2v/common/offload/manager.py  +46 -362
lightx2v/common/ops/attn/flash_attn.py  +4 -3
lightx2v/common/ops/attn/template.py  +3 -0
lightx2v/common/ops/conv/conv2d.py  +5 -11
lightx2v/common/ops/conv/conv3d.py  +3 -9
lightx2v/common/ops/embedding/embedding_weight.py  +5 -3
lightx2v/common/ops/mm/mm_weight.py  +287 -170
lightx2v/common/ops/norm/layer_norm_weight.py  +93 -85
lightx2v/common/ops/norm/rms_norm_weight.py  +73 -56
lightx2v/common/ops/tensor/tensor.py  +51 -29
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py  +5 -9
lightx2v/models/input_encoders/hf/wan/t5/model.py  +9 -16
lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py  +3 -2
lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py  +4 -1
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  +2 -7
lightx2v/models/networks/hunyuan_video/model.py  +4 -3
lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py  +60 -27
lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py  +4 -1
lightx2v/models/networks/qwen_image/model.py  +2 -3
lightx2v/common/modules/weight_module.py  (+8 -29)

@@ -23,35 +23,6 @@ class WeightModule:
         if hasattr(parameter, "load"):
             parameter.load(weight_dict)

-    def calculate_size(self):
-        total_size = 0
-        for _, module in self._modules.items():
-            if hasattr(module, "_calculate_size"):
-                total_size += module._calculate_size()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "_calculate_size"):
-                total_size += parameter._calculate_size()
-        return total_size
-
-    def load_from_disk(self):
-        for _, module in self._modules.items():
-            if hasattr(module, "load_from_disk"):
-                module.load_from_disk()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "load_from_disk"):
-                parameter.load_from_disk()
-
-    def clear(self):
-        for _, module in self._modules.items():
-            if hasattr(module, "clear"):
-                module.clear()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "clear"):
-                parameter.clear()
-
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
...
@@ -74,6 +45,14 @@ class WeightModule:
             module.load_state_dict(destination, block_index, adapter_block_index)
         return destination

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        for _, param in self._parameters.items():
+            if param is not None:
+                param.load_state_dict_from_disk(block_index, adapter_block_index)
+        for _, module in self._modules.items():
+            if module is not None:
+                module.load_state_dict_from_disk(block_index, adapter_block_index)
+
     def named_parameters(self, prefix=""):
         for name, param in self._parameters.items():
             if param is not None:
...
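The new load_state_dict_from_disk hook simply fans out to every registered parameter wrapper and child module. Below is a minimal sketch of that delegation pattern, assuming hypothetical stand-in classes (FakeParam and FakeModule are illustrative only, not part of the repository):

# Sketch of the delegation pattern added above: a container forwards the call
# to every parameter wrapper in self._parameters and every child in self._modules.
class FakeParam:
    def __init__(self, name):
        self.name = name

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        print(f"load {self.name} for block {block_index}")


class FakeModule:
    def __init__(self):
        self._parameters = {"w": FakeParam("w"), "b": FakeParam("b")}
        self._modules = {}

    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
        for _, param in self._parameters.items():
            if param is not None:
                param.load_state_dict_from_disk(block_index, adapter_block_index)
        for _, module in self._modules.items():
            if module is not None:
                module.load_state_dict_from_disk(block_index, adapter_block_index)


if __name__ == "__main__":
    FakeModule().load_state_dict_from_disk(block_index=3)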
lightx2v/common/offload/manager.py  (+46 -362)

-import gc
-import queue
-import threading
-import time
-from collections import OrderedDict
-
 import torch
-from loguru import logger
 from packaging.version import parse
+
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 class WeightAsyncStreamManager(object):
     def __init__(self, offload_granularity):
         self.offload_granularity = offload_granularity
-        self.init_stream = torch.cuda.Stream(priority=0)
+        self.init_stream = torch_device_module.Stream(priority=0)
+        self.need_init_first_buffer = True
         torch_version = parse(torch.__version__.split("+")[0])
-        if torch_version >= parse("2.7"):
-            self.cuda_load_stream = torch.cuda.Stream(priority=1)
-            self.compute_stream = torch.cuda.Stream(priority=1)
+        if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
+            self.cuda_load_stream = torch_device_module.Stream(priority=1)
+            self.compute_stream = torch_device_module.Stream(priority=1)
         else:
-            self.cuda_load_stream = torch.cuda.Stream(priority=0)
-            self.compute_stream = torch.cuda.Stream(priority=-1)
+            self.cuda_load_stream = torch_device_module.Stream(priority=0)
+            self.compute_stream = torch_device_module.Stream(priority=-1)

+    def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
+        self.need_init_first_buffer = True
+        if self.offload_granularity == "block":
+            assert blocks_cpu_buffer is not None
+            self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
+        elif self.offload_granularity == "phase":
+            assert phases_cpu_buffer is not None
+            self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
+        else:
+            raise NotImplementedError
+
     def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
+        self.need_init_first_buffer = True
         if self.offload_granularity == "block":
             assert blocks_cuda_buffer is not None
             self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
...
@@ -32,17 +42,32 @@ class WeightAsyncStreamManager(object):
             raise NotImplementedError

     def init_first_buffer(self, blocks, adapter_block_idx=None):
-        if self.offload_granularity == "block":
-            with torch.cuda.stream(self.init_stream):
-                self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
-        else:
-            with torch.cuda.stream(self.init_stream):
-                self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
+        with torch_device_module.stream(self.init_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
+            else:
+                if self.offload_granularity == "block":
+                    self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
+                else:
+                    self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
         self.init_stream.synchronize()
+        self.need_init_first_buffer = False

     def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
-        with torch.cuda.stream(self.cuda_load_stream):
-            self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
+        with torch_device_module.stream(self.cuda_load_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
+                self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
+            else:
+                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
+
+    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
+        with torch_device_module.stream(self.cuda_load_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
+                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
+            else:
+                self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)

     def swap_blocks(self):
         self.cuda_load_stream.synchronize()
...
@@ -52,347 +77,6 @@ class WeightAsyncStreamManager(object):
             self.cuda_buffers[0],
         )

-    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
-        with torch.cuda.stream(self.cuda_load_stream):
-            self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
-
     def swap_phases(self):
         self.cuda_load_stream.synchronize()
         self.compute_stream.synchronize()
-
-
-class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
-    def __init__(
-        self,
-        blocks_num,
-        offload_ratio=1,
-        phases_num=1,
-        num_disk_workers=1,
-        max_memory=2,
-        offload_gra="phase",
-    ):
-        super().__init__(blocks_num, offload_ratio, phases_num)
-        self.offload_gra = offload_gra
-        self.worker_stop_event = threading.Event()
-        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
-        self.disk_task_queue = queue.PriorityQueue()
-        self.disk_workers = []
-        self.release_workers = []
-        self._start_disk_workers(num_disk_workers)
-        self.initial_prefetch_done = False
-        self.pending_tasks = {}
-        self.task_lock = threading.Lock()
-        self.last_used_time = {}
-        self.time_lock = threading.Lock()
-
-    def _start_disk_workers(self, num_workers):
-        for i in range(num_workers):
-            if self.offload_gra == "phase":
-                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
-            else:
-                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
-            worker.start()
-            self.disk_workers.append(worker)
-
-    def _disk_worker_loop(self):
-        while not self.worker_stop_event.is_set():
-            try:
-                _, task = self.disk_task_queue.get(timeout=0.5)
-                if task is None:
-                    break
-                block_idx, phase_idx, phase = task
-                phase.load_from_disk()
-                self.pin_memory_buffer.push((block_idx, phase_idx), phase)
-                with self.task_lock:
-                    if (block_idx, phase_idx) in self.pending_tasks:
-                        del self.pending_tasks[(block_idx, phase_idx)]
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Disk worker thread error: {e}")
-
-    def _disk_worker_loop_block(self):
-        while not self.worker_stop_event.is_set():
-            try:
-                _, task = self.disk_task_queue.get(timeout=0.5)
-                if task is None:
-                    break
-                block_idx, block = task
-                for phase in block.compute_phases:
-                    phase.load_from_disk()
-                self.pin_memory_buffer.push(block_idx, block)
-                with self.task_lock:
-                    if block_idx in self.pending_tasks:
-                        del self.pending_tasks[block_idx]
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Disk worker thread error: {e}")
-
-    def _async_prefetch_block(self, blocks, next_block_idx=None):
-        if next_block_idx is None:
-            next_block_idx = self.pin_memory_buffer.get_max_block_index()
-            if next_block_idx < 0:
-                next_block_idx = 0
-        if next_block_idx == self.blocks_num:
-            return
-        if self.offload_gra == "phase":
-            for phase_idx in range(self.phases_num):
-                obj_key = (next_block_idx, phase_idx)
-                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
-                    continue
-                with self.task_lock:
-                    self.pending_tasks[obj_key] = True
-                phase = blocks[next_block_idx].compute_phases[phase_idx]
-                priority_key = (next_block_idx, phase_idx)
-                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
-        else:
-            obj_key = next_block_idx
-            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
-                return
-            with self.task_lock:
-                self.pending_tasks[obj_key] = True
-            block = blocks[next_block_idx]
-            self.disk_task_queue.put((obj_key, (next_block_idx, block)))
-
-    def _sync_prefetch_block(self, blocks):
-        block_idx = 0
-        while not self.pin_memory_buffer.is_nearly_full():
-            if self.offload_gra == "phase":
-                for phase_idx in range(self.phases_num):
-                    phase = blocks[block_idx].compute_phases[phase_idx]
-                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
-                    phase.load_from_disk()
-                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
-            else:
-                block = blocks[block_idx]
-                logger.info(f"Synchronous loading: block={block_idx}")
-                for phase in block.compute_phases:
-                    phase.load_from_disk()
-                self.pin_memory_buffer.push(block_idx, block)
-            block_idx += 1
-            if block_idx == self.blocks_num:
-                break
-
-    def prefetch_weights_from_disk(self, blocks):
-        if self.initial_prefetch_done:
-            return
-        self._sync_prefetch_block(blocks)
-        self.initial_prefetch_done = True
-
-    def prefetch_weights(self, block_idx, blocks):
-        obj_key = block_idx
-        if not self.pin_memory_buffer.exists(obj_key):
-            is_loading = False
-            with self.task_lock:
-                if obj_key in self.pending_tasks:
-                    is_loading = True
-            if is_loading:
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists(obj_key):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}")
-            else:
-                logger.info("Not find prefetch block={block_idx} task.")
-                logger.info("Sync prefetch block={block_idx}.")
-                self._async_prefetch_block(blocks, block_idx)
-                start_time = time.time()
-                for phase_idx in self.phases_num:
-                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
-                        time.sleep(0.001)
-                        if time.time() - start_time > 15:
-                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-        with torch.cuda.stream(self.cuda_load_stream):
-            block = self.pin_memory_buffer.get(obj_key)
-            block.to_cuda_async()
-            self.cuda_buffers[2] = (obj_key, block)
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_blocks_num:
-                if self.cuda_buffers[1] is not None:
-                    old_key, old_block = self.cuda_buffers[1]
-                    if self.pin_memory_buffer.exists(old_key):
-                        old_block.to_cpu_async()
-                    self.pin_memory_buffer.pop(old_key)
-
-    def prefetch_phase(self, block_idx, phase_idx, blocks):
-        obj_key = (block_idx, phase_idx)
-        if not self.pin_memory_buffer.exists(obj_key):
-            is_loading = False
-            with self.task_lock:
-                if obj_key in self.pending_tasks:
-                    is_loading = True
-            if is_loading:
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists(obj_key):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-            else:
-                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
-                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
-                self._async_prefetch_block(blocks, block_idx)
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-        with torch.cuda.stream(self.cuda_load_stream):
-            phase = self.pin_memory_buffer.get(obj_key)
-            phase.to_cuda_async()
-            self.cuda_buffers[2] = (obj_key, phase)
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
-                if self.cuda_buffers[1] is not None:
-                    old_key, old_phase = self.cuda_buffers[1]
-                    if self.pin_memory_buffer.exists(old_key):
-                        old_phase.to_cpu_async()
-                    self.pin_memory_buffer.pop(old_key)
-
-    def shutdown(self):
-        self.worker_stop_event.set()
-        while not self.disk_task_queue.empty():
-            try:
-                self.disk_task_queue.get_nowait()
-            except queue.Empty:
-                continue
-        for _ in self.disk_workers:
-            self.disk_task_queue.put((0, None))
-        for worker in self.disk_workers:
-            worker.join(timeout=5)
-        for worker in self.release_workers:
-            worker.join(timeout=5)
-        logger.info("All worker threads have been closed")
-
-    def clear(self):
-        self.pin_memory_buffer.clear()
-        self.shutdown()
-
-
-class MemoryBuffer:
-    def __init__(self, max_memory_bytes=8 * (1024**3)):
-        self.cache = OrderedDict()
-        self.max_mem = max_memory_bytes
-        self.used_mem = 0
-        self.obj_size_map = {}
-        self.lock = threading.Lock()
-        self.insertion_order = []
-        self.insertion_index = 0
-
-    def push(self, key, obj):
-        with self.lock:
-            if key in self.cache:
-                return
-            if hasattr(obj, "compute_phases"):
-                obj_idx = key
-                if len(self.obj_size_map) == 0:
-                    _size = 0
-                    for phase in obj.compute_phases:
-                        _size += phase.calculate_size()
-                    self.obj_size_map[0] = _size
-                size = self.obj_size_map[0]
-            else:
-                _, obj_idx = key
-                if obj_idx not in self.obj_size_map:
-                    self.obj_size_map[obj_idx] = obj.calculate_size()
-                size = self.obj_size_map[obj_idx]
-            self.cache[key] = (size, obj, self.insertion_index)
-            self.insertion_order.append((key, self.insertion_index))
-            self.insertion_index += 1
-            self.used_mem += size
-
-    def _remove_key(self, key):
-        if key in self.cache:
-            size, obj, idx = self.cache.pop(key)
-            try:
-                if hasattr(obj, "compute_phases"):
-                    for phase in obj.compute_phases:
-                        phase.clear()
-                else:
-                    obj.clear()
-            except Exception as e:
-                logger.info(f"Error clearing obj: {e}")
-            self.used_mem -= size
-            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
-
-    def get(self, key, default=None):
-        with self.lock:
-            if key in self.cache:
-                size, obj, idx = self.cache[key]
-                return obj
-            return default
-
-    def exists(self, key):
-        with self.lock:
-            return key in self.cache
-
-    def pop_front(self):
-        with self.lock:
-            if not self.insertion_order:
-                return False
-            front_key, _ = self.insertion_order[0]
-            self._remove_key(front_key)
-            return True
-
-    def pop(self, key):
-        with self.lock:
-            if key in self.cache:
-                self._remove_key(key)
-                return True
-            return False
-
-    def is_nearly_full(self):
-        with self.lock:
-            return self.used_mem >= self.max_mem * 0.9
-
-    def get_max_block_index(self):
-        with self.lock:
-            if not self.cache:
-                return -1
-            if isinstance(list(self.cache.keys())[-1], tuple):
-                return (list(self.cache.keys())[-1][0] + 1) % 40
-            else:
-                return (list(self.cache.keys())[-1] + 1) % 40
-
-    def clear(self):
-        with self.lock:
-            for key in list(self.cache.keys()):
-                self._remove_key(key)
-            self.insertion_order = []
-            self.insertion_index = 0
-            self.used_mem = 0
-        torch.cuda.empty_cache()
-        gc.collect()
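The core platform fix in this file is the torch_device_module = getattr(torch, AI_DEVICE) indirection, which lets the same stream code run on any backend whose torch submodule mirrors the torch.cuda interface. The following is a hedged, standalone sketch of how that indirection resolves at runtime; it assumes a CUDA-capable machine and hard-codes AI_DEVICE = "cuda" instead of importing the real lightx2v_platform global:

import torch

AI_DEVICE = "cuda"  # assumption: the real value comes from lightx2v_platform.base.global_var
torch_device_module = getattr(torch, AI_DEVICE)  # resolves to torch.cuda here; torch.npu, torch.xpu, ... elsewhere

load_stream = torch_device_module.Stream(priority=0)   # side stream for weight transfers
with torch_device_module.stream(load_stream):           # issue work on that stream
    x = torch.ones(4, device=AI_DEVICE)
load_stream.synchronize()                                # wait for the side stream to finish
print(x)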
lightx2v/common/ops/attn/flash_attn.py  (+4 -3)

@@ -73,9 +73,10 @@ class FlashAttn3Weight(AttnWeightTemplate):
             bs = 1
         elif len(q.shape) == 4:
             bs = q.shape[0]
-            q = q.reshape(-1, q.shape[-2], q.shape[-1])
-            k = k.reshape(-1, k.shape[-2], k.shape[-1])
-            v = v.reshape(-1, v.shape[-2], v.shape[-1])
+            if model_cls is not None and model_cls in ["hunyuan_video_1.5"]:
+                q = q.reshape(-1, q.shape[-2], q.shape[-1])
+                k = k.reshape(-1, k.shape[-2], k.shape[-1])
+                v = v.reshape(-1, v.shape[-2], v.shape[-1])
         x = flash_attn_varlen_func_v3(
             q,
             k,
...
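The guarded branch above only flattens 4-D q/k/v into the (total_tokens, heads, head_dim) layout expected by the varlen kernel when the model is hunyuan_video_1.5. A standalone sketch of that reshape, with made-up shapes for illustration:

import torch

bs, seq, heads, head_dim = 2, 8, 4, 64  # illustrative shapes only
q = torch.randn(bs, seq, heads, head_dim)

if len(q.shape) == 4:
    bs = q.shape[0]
    q = q.reshape(-1, q.shape[-2], q.shape[-1])  # -> (bs * seq, heads, head_dim)

print(q.shape)  # torch.Size([16, 4, 64])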
lightx2v/common/ops/attn/template.py  (+3 -0)

@@ -30,3 +30,6 @@ class AttnWeightTemplate(metaclass=ABCMeta):
     def load_state_dict(self, destination, block_index, adapter_block_inde=None):
         return {}
+
+    def load_state_dict_from_disk(self, block_index, adapter_block_inde=None):
+        pass
lightx2v/common/ops/conv/conv2d.py  (+5 -11)

@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
 import torch

 from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


 class Conv2dWeightTemplate(metaclass=ABCMeta):
...
@@ -34,8 +35,8 @@ class Conv2dWeight(Conv2dWeightTemplate):
         super().__init__(weight_name, bias_name, stride, padding, dilation, groups)

     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
+        self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None

     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
...
@@ -47,9 +48,9 @@ class Conv2dWeight(Conv2dWeightTemplate):
         self.bias = self.bias.cpu(non_blocking=non_blocking)

     def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
+        self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
         if self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+            self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)

     def state_dict(self, destination=None):
         if destination is None:
...
@@ -58,10 +59,3 @@ class Conv2dWeight(Conv2dWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
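For reference, the clear() method removed above follows a drop-and-null pattern: delete a fixed set of tensor attributes so their storage can be released, then leave each attribute behind as None so later hasattr()/is None checks still pass. A toy sketch of that pattern (ToyWeight is illustrative only, not a repository class):

import torch


class ToyWeight:
    def __init__(self):
        self.weight = torch.randn(4, 4)
        self.bias = torch.randn(4)

    def clear(self):
        # Drop the tensors, then re-create the attributes as None placeholders.
        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
        for attr in attrs:
            if hasattr(self, attr):
                delattr(self, attr)
                setattr(self, attr, None)


w = ToyWeight()
w.clear()
print(w.weight, w.bias)  # None None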
lightx2v/common/ops/conv/conv3d.py  (+3 -9)

@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
 import torch

 from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


 class Conv3dWeightTemplate(metaclass=ABCMeta):
...
@@ -70,9 +71,9 @@ class Conv3dWeight(Conv3dWeightTemplate):
         return input_tensor

     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
@@ -91,10 +92,3 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias_name is not None:
             destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias  # .cpu().detach().clone()
         return destination
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
lightx2v/common/ops/embedding/embedding_weight.py  (+5 -3)

@@ -5,12 +5,14 @@ import torch
 import torch.nn.functional as F

 from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


 class EmbeddingWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.weight_name = weight_name
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
...
@@ -19,7 +21,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
     def load(self, weight_dict):
         if not self.lazy_load:
             if self.create_cuda_buffer:
-                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
+                self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE)
             else:
                 device = weight_dict[self.weight_name].device
                 if device.type == "cpu":
...
@@ -32,7 +34,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight = weight_dict[self.weight_name]

     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
lightx2v/common/ops/mm/mm_weight.py
View file @
74eeb429
...
@@ -9,6 +9,7 @@ from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tens
...
@@ -9,6 +9,7 @@ from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tens
from
lightx2v.utils.global_paras
import
CALIB
from
lightx2v.utils.global_paras
import
CALIB
from
lightx2v.utils.quant_utils
import
FloatQuantizer
,
IntegerQuantizer
from
lightx2v.utils.quant_utils
import
FloatQuantizer
,
IntegerQuantizer
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
from
lightx2v.utils.registry_factory
import
MM_WEIGHT_REGISTER
from
lightx2v_platform.base.global_var
import
AI_DEVICE
try
:
try
:
from
lightx2v_kernel.gemm
import
(
from
lightx2v_kernel.gemm
import
(
...
@@ -69,10 +70,11 @@ except ImportError:
...
@@ -69,10 +70,11 @@ except ImportError:
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
self
.
weight_name
=
weight_name
self
.
weight_name
=
weight_name
self
.
bias_name
=
bias_name
self
.
bias_name
=
bias_name
self
.
create_cuda_buffer
=
create_cuda_buffer
self
.
create_cuda_buffer
=
create_cuda_buffer
self
.
create_cpu_buffer
=
create_cpu_buffer
self
.
lazy_load
=
lazy_load
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
self
.
lazy_load_file
=
lazy_load_file
self
.
is_post_adapter
=
is_post_adapter
self
.
is_post_adapter
=
is_post_adapter
...
@@ -90,11 +92,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
...
@@ -90,11 +92,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
self
.
config
=
config
self
.
config
=
config
def
to_cuda
(
self
,
non_blocking
=
False
):
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
pin_weight
.
cuda
(
non_blocking
=
non_blocking
)
self
.
weight
=
self
.
pin_weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_weight_scale"
):
if
hasattr
(
self
,
"pin_weight_scale"
):
self
.
weight_scale
=
self
.
pin_weight_scale
.
cuda
(
non_blocking
=
non_blocking
)
self
.
weight_scale
=
self
.
pin_weight_scale
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_bias"
)
and
self
.
pin_bias
is
not
None
:
if
hasattr
(
self
,
"pin_bias"
)
and
self
.
pin_bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
cuda
(
non_blocking
=
non_blocking
)
self
.
bias
=
self
.
pin_bias
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
to_cpu
(
self
,
non_blocking
=
False
):
def
to_cpu
(
self
,
non_blocking
=
False
):
if
hasattr
(
self
,
"pin_weight"
):
if
hasattr
(
self
,
"pin_weight"
):
...
@@ -113,44 +115,63 @@ class MMWeightTemplate(metaclass=ABCMeta):
...
@@ -113,44 +115,63 @@ class MMWeightTemplate(metaclass=ABCMeta):
@
MM_WEIGHT_REGISTER
(
"Default"
)
@
MM_WEIGHT_REGISTER
(
"Default"
)
class
MMWeight
(
MMWeightTemplate
):
class
MMWeight
(
MMWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
load
(
self
,
weight_dict
):
def
load
(
self
,
weight_dict
):
if
self
.
create_cuda_buffer
:
if
self
.
create_cuda_buffer
:
self
.
weight_cuda_buffer
=
weight_dict
[
self
.
weight_name
].
t
().
cuda
()
self
.
_load_cuda_buffers
(
weight_dict
)
if
self
.
bias_name
is
not
None
:
elif
self
.
create_cpu_buffer
:
self
.
bias_cuda_buffer
=
weight_dict
[
self
.
bias_name
].
cuda
()
self
.
_load_cpu_pin_buffers
()
else
:
self
.
_load_default_tensors
(
weight_dict
)
def
_get_source_tensor
(
self
,
source_name
,
weight_dict
=
None
):
if
self
.
lazy_load
:
return
self
.
lazy_load_file
.
get_tensor
(
source_name
)
return
weight_dict
[
source_name
]
def
_create_pin_tensor
(
self
,
tensor
,
transpose
=
False
):
pin_tensor
=
torch
.
empty
(
tensor
.
shape
,
pin_memory
=
True
,
dtype
=
tensor
.
dtype
)
pin_tensor
=
pin_tensor
.
copy_
(
tensor
)
if
transpose
:
pin_tensor
=
pin_tensor
.
t
()
del
tensor
return
pin_tensor
def
_load_cuda_buffers
(
self
,
weight_dict
):
self
.
weight_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
weight_name
,
weight_dict
).
t
().
to
(
AI_DEVICE
)
if
self
.
bias_name
is
not
None
:
self
.
bias_cuda_buffer
=
self
.
_get_source_tensor
(
self
.
bias_name
,
weight_dict
).
to
(
AI_DEVICE
)
def
_load_cpu_pin_buffers
(
self
):
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
if
self
.
bias_name
is
not
None
:
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
else
:
else
:
self
.
bias
=
None
self
.
pin_bias
=
None
def
_load_default_tensors
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cpu"
:
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_tensor
=
weight_dict
[
self
.
weight_name
]
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
,
transpose
=
True
)
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
]).
t
()
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_tensor
=
weight_dict
[
self
.
bias_name
]
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
pin_bias
=
self
.
_create_pin_tensor
(
bias_tensor
)
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
self
.
pin_bias
=
None
self
.
pin_bias
=
None
del
weight_dict
[
self
.
weight_name
]
del
weight_dict
[
self
.
weight_name
]
else
:
else
:
self
.
weight
=
weight_dict
[
self
.
weight_name
].
t
()
self
.
weight
=
weight_dict
[
self
.
weight_name
].
t
()
if
self
.
bias_name
is
not
None
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
if
self
.
bias_name
is
not
None
else
None
self
.
bias
=
weight_dict
[
self
.
bias_name
]
else
:
self
.
bias
=
None
def
_calculate_size
(
self
):
if
self
.
bias
is
not
None
:
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
bias
.
numel
()
*
self
.
bias
.
element_size
()
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
def
apply
(
self
,
input_tensor
):
def
apply
(
self
,
input_tensor
):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
...
@@ -169,6 +190,28 @@ class MMWeight(MMWeightTemplate):
...
@@ -169,6 +190,28 @@ class MMWeight(MMWeightTemplate):
destination
[
self
.
bias_name
]
=
self
.
pin_bias
if
hasattr
(
self
,
"pin_bias"
)
else
self
.
bias
destination
[
self
.
bias_name
]
=
self
.
pin_bias
if
hasattr
(
self
,
"pin_bias"
)
else
self
.
bias
return
destination
return
destination
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
def
load_state_dict
(
self
,
destination
,
block_index
,
adapter_block_index
=
None
):
def
load_state_dict
(
self
,
destination
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
assert
adapter_block_index
is
not
None
...
@@ -195,19 +238,20 @@ class MMWeight(MMWeightTemplate):
...
@@ -195,19 +238,20 @@ class MMWeight(MMWeightTemplate):
@
MM_WEIGHT_REGISTER
(
"Default-Force-FP32"
)
@
MM_WEIGHT_REGISTER
(
"Default-Force-FP32"
)
class
MMWeightForceFP32
(
MMWeight
):
class
MMWeightForceFP32
(
MMWeight
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
load
(
self
,
weight_dict
):
def
load
(
self
,
weight_dict
):
super
().
load
(
weight_dict
)
if
not
self
.
lazy_load
:
self
.
weight
=
self
.
weight
.
to
(
torch
.
float32
)
super
().
load
(
weight_dict
)
if
hasattr
(
self
,
"bias"
)
and
self
.
bias
is
not
None
:
self
.
weight
=
self
.
weight
.
to
(
torch
.
float32
)
self
.
bias
=
self
.
bias
.
to
(
torch
.
float32
)
if
hasattr
(
self
,
"bias"
)
and
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
to
(
torch
.
float32
)
class
MMWeightQuantTemplate
(
MMWeightTemplate
):
class
MMWeightQuantTemplate
(
MMWeightTemplate
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
weight_scale_name
=
self
.
weight_name
.
removesuffix
(
".weight"
)
+
".weight_scale"
self
.
weight_scale_name
=
self
.
weight_name
.
removesuffix
(
".weight"
)
+
".weight_scale"
self
.
load_func
=
None
self
.
load_func
=
None
self
.
weight_need_transpose
=
True
self
.
weight_need_transpose
=
True
...
@@ -215,87 +259,133 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -215,87 +259,133 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self
.
lazy_load
=
lazy_load
self
.
lazy_load
=
lazy_load
self
.
lazy_load_file
=
lazy_load_file
self
.
lazy_load_file
=
lazy_load_file
self
.
infer_dtype
=
GET_DTYPE
()
self
.
infer_dtype
=
GET_DTYPE
()
self
.
bias_force_fp32
=
False
# =========================
# =========================
# weight load functions
# weight load functions
# =========================
# =========================
def
load
(
self
,
weight_dict
):
self
.
load_quantized
(
weight_dict
)
if
self
.
weight_need_transpose
:
if
hasattr
(
self
,
"weight"
)
and
self
.
weight
is
not
None
:
self
.
weight
=
self
.
weight
.
t
()
if
hasattr
(
self
,
"pin_weight"
)
and
self
.
pin_weight
is
not
None
:
self
.
pin_weight
=
self
.
pin_weight
.
t
()
if
hasattr
(
self
,
"weight_cuda_buffer"
)
and
self
.
weight_cuda_buffer
is
not
None
:
self
.
weight_cuda_buffer
=
self
.
weight_cuda_buffer
.
t
()
def
load_from_disk
(
self
):
# Need Rewrite
def
load_quantized
(
self
,
weight_dict
):
if
not
torch
.
_dynamo
.
is_compiling
():
if
self
.
create_cuda_buffer
:
self
.
weight
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
pin_memory
()
self
.
_load_cuda_buffers
(
weight_dict
)
self
.
weight_scale
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
).
float
().
pin_memory
()
elif
self
.
create_cpu_buffer
:
if
self
.
bias_name
is
not
None
:
self
.
_load_cpu_pin_buffers
()
self
.
bias
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
).
to
(
self
.
infer_dtype
).
pin_memory
()
else
:
else
:
self
.
weight
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
_load_default_tensors
(
weight_dict
)
self
.
weight_scale
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
).
float
()
if
self
.
bias_name
is
not
None
:
self
.
bias
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
).
to
(
self
.
infer_dtype
)
if
self
.
weight_need_transpose
:
def
_load_cuda_buffers
(
self
,
weight_dict
):
self
.
weight
=
self
.
weight
.
t
()
source
=
self
.
lazy_load_file
if
self
.
lazy_load
else
weight_dict
self
.
weight_cuda_buffer
,
self
.
weight_scale_cuda_buffer
=
self
.
_get_cuda_tensor_pair
(
source
,
self
.
lazy_load
)
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
self
.
lazy_load
)
def
load
(
self
,
weight_dict
):
def
_get_cuda_tensor_pair
(
self
,
source
,
is_lazy
):
if
is_lazy
:
weight
=
source
.
get_tensor
(
self
.
weight_name
).
to
(
AI_DEVICE
)
scale
=
source
.
get_tensor
(
self
.
weight_scale_name
).
float
().
to
(
AI_DEVICE
)
else
:
weight
=
source
[
self
.
weight_name
].
to
(
AI_DEVICE
)
scale
=
source
[
self
.
weight_scale_name
].
float
().
to
(
AI_DEVICE
)
return
weight
,
scale
def
_get_cuda_bias_tensor
(
self
,
source
,
is_lazy
):
if
self
.
bias_name
is
None
:
return
None
if
is_lazy
:
bias
=
source
.
get_tensor
(
self
.
bias_name
)
dtype
=
self
.
infer_dtype
else
:
bias
=
source
[
self
.
bias_name
]
dtype
=
bias
.
dtype
if
self
.
bias_force_fp32
:
bias
=
bias
.
to
(
torch
.
float32
)
else
:
bias
=
bias
.
to
(
dtype
)
return
bias
.
to
(
AI_DEVICE
)
def
_load_cpu_pin_buffers
(
self
):
self
.
pin_weight
,
self
.
pin_weight_scale
=
self
.
_get_cpu_pin_tensor_pair
(
self
.
lazy_load_file
,
is_lazy
=
True
)
self
.
pin_bias
=
self
.
_get_cpu_pin_bias_tensor
(
self
.
lazy_load_file
,
is_lazy
=
True
)
self
.
bias
=
None
def
_get_cpu_pin_tensor_pair
(
self
,
source
,
is_lazy
):
if
is_lazy
:
weight_tensor
=
source
.
get_tensor
(
self
.
weight_name
)
scale_tensor
=
source
.
get_tensor
(
self
.
weight_scale_name
)
scale_dtype
=
torch
.
float
else
:
weight_tensor
=
source
[
self
.
weight_name
]
scale_tensor
=
source
[
self
.
weight_scale_name
]
scale_dtype
=
torch
.
float
pin_weight
=
self
.
_create_pin_tensor
(
weight_tensor
)
pin_scale
=
self
.
_create_pin_tensor
(
scale_tensor
,
scale_dtype
)
return
pin_weight
,
pin_scale
def
_get_cpu_pin_bias_tensor
(
self
,
source
,
is_lazy
):
if
self
.
bias_name
is
None
:
return
None
if
is_lazy
:
bias_tensor
=
source
.
get_tensor
(
self
.
bias_name
)
if
not
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
self
.
infer_dtype
)
else
:
bias_tensor
=
source
[
self
.
bias_name
]
if
self
.
bias_force_fp32
:
bias_tensor
=
bias_tensor
.
to
(
torch
.
float32
)
return
self
.
_create_pin_tensor
(
bias_tensor
)
def
_create_pin_tensor
(
self
,
tensor
,
dtype
=
None
):
dtype
=
dtype
or
tensor
.
dtype
pin_tensor
=
torch
.
empty
(
tensor
.
shape
,
pin_memory
=
True
,
dtype
=
dtype
)
pin_tensor
.
copy_
(
tensor
)
del
tensor
return
pin_tensor
def
_load_default_tensors
(
self
,
weight_dict
):
if
not
self
.
lazy_load
:
if
not
self
.
lazy_load
:
self
.
load_func
(
weight_dict
)
self
.
weight
,
self
.
weight_scale
,
self
.
pin_weight
,
self
.
pin_weight_scale
=
self
.
_get_device_tensor_pair
(
weight_dict
)
if
self
.
weight_need_transpose
:
self
.
_load_default_bias
(
weight_dict
)
if
hasattr
(
self
,
"weight"
):
else
:
self
.
weight
=
self
.
weight
.
t
()
self
.
bias
=
None
if
hasattr
(
self
,
"pin_weight"
):
self
.
pin_bias
=
None
self
.
pin_weight
=
self
.
pin_weight
.
t
()
if
hasattr
(
self
,
"weight_cuda_buffer"
):
self
.
weight_cuda_buffer
=
self
.
weight_cuda_buffer
.
t
()
def
clear
(
self
):
attrs
=
[
"weight"
,
"weight_scale"
,
"bias"
,
"pin_weight"
,
"pin_weight_scale"
,
"pin_bias"
]
for
attr
in
attrs
:
if
hasattr
(
self
,
attr
):
delattr
(
self
,
attr
)
setattr
(
self
,
attr
,
None
)
def
_calculate_size
(
self
):
if
self
.
bias
is
not
None
:
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
+
self
.
bias
.
numel
()
*
self
.
bias
.
element_size
()
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
def
load_quantized
(
self
,
weight_dict
):
def
_get_device_tensor_pair
(
self
,
source
):
if
self
.
create_cuda_buffer
:
device
=
source
[
self
.
weight_name
].
device
# move to cuda buffer
if
device
.
type
==
"cpu"
:
self
.
weight
_cuda_buffer
=
weight_dict
[
self
.
weight_name
].
cuda
(
)
pin_
weight
,
pin_scale
=
self
.
_get_cpu_pin_tensor_pair
(
source
,
is_lazy
=
False
)
self
.
weight_scale_cuda_buffer
=
weight_dict
[
self
.
weight_scale
_name
].
float
().
cuda
()
return
None
,
None
,
pin_
weight
,
pin
_scale
else
:
else
:
device
=
weight_dict
[
self
.
weight_name
].
device
return
source
[
self
.
weight_name
],
source
[
self
.
weight_scale_name
].
float
(),
None
,
None
if
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
pin_weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
pin_weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
def
_load_default_bias
(
self
,
source
):
weight_scale_dtype
=
torch
.
float
if
self
.
bias_name
is
None
:
self
.
pin_weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
self
.
bias
=
None
self
.
pin_weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
self
.
pin_bias
=
None
del
weight_dict
[
self
.
weight_name
]
self
.
bias_cuda_buffer
=
None
else
:
return
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
].
float
()
if
self
.
bias_name
is
not
None
:
if
self
.
create_cuda_buffer
:
if
self
.
create_cuda_buffer
:
self
.
bias_cuda_buffer
=
self
.
_get_cuda_bias_tensor
(
source
,
is_lazy
=
False
)
# move to cuda buffer
self
.
bias_cuda_buffer
=
weight_dict
[
self
.
bias_name
].
cuda
()
else
:
device
=
weight_dict
[
self
.
bias_name
].
device
if
device
.
type
==
"cpu"
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
else
:
self
.
bias
=
None
self
.
bias
=
None
self
.
pin_bias
=
None
self
.
pin_bias
=
None
else
:
bias_tensor
=
source
[
self
.
bias_name
].
float
()
if
self
.
bias_force_fp32
else
source
[
self
.
bias_name
]
device
=
bias_tensor
.
device
if
device
.
type
==
"cpu"
:
self
.
pin_bias
=
self
.
_get_cpu_pin_bias_tensor
(
source
,
is_lazy
=
False
)
self
.
bias
=
None
else
:
self
.
bias
=
bias_tensor
self
.
pin_bias
=
None
def
load_fp8_perchannel_sym
(
self
,
weight_dict
):
def
load_fp8_perchannel_sym
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
...
@@ -320,7 +410,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -320,7 +410,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
load_mxfp4
(
self
,
weight_dict
):
def
load_mxfp4
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
(
).
to
(
torch
.
bfloat16
)
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp4_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp4_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
else
:
...
@@ -343,7 +433,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -343,7 +433,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
load_mxfp6
(
self
,
weight_dict
):
def
load_mxfp6
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
(
).
to
(
torch
.
bfloat16
)
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp6_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp6_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
else
:
...
@@ -366,7 +456,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -366,7 +456,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
load_mxfp8
(
self
,
weight_dict
):
def
load_mxfp8
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
(
).
to
(
torch
.
bfloat16
)
self
.
weight
=
weight_dict
[
self
.
weight_name
].
to
(
AI_DEVICE
).
to
(
torch
.
bfloat16
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp8_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
scaled_mxfp8_quant
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
self
.
weight
,
self
.
weight_scale
=
self
.
weight
.
to
(
device
),
self
.
weight_scale
.
to
(
device
)
else
:
else
:
...
@@ -424,19 +514,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -424,19 +514,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
if
self
.
create_cuda_buffer
:
if
self
.
create_cuda_buffer
:
# move to cuda buffer
self
.
bias_cuda_buffer
=
weight_dict
[
self
.
bias_name
].
to
(
AI_DEVICE
)
self
.
bias_cuda_buffer
=
weight_dict
[
self
.
bias_name
].
cuda
()
else
:
else
:
device
=
weight_dict
[
self
.
bias_name
].
device
device
=
weight_dict
[
self
.
bias_name
].
device
if
device
.
type
==
"cuda"
:
if
device
.
type
==
"cpu"
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
elif
device
.
type
==
"cpu"
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
self
.
pin_bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
self
.
bias
=
weight_dict
[
self
.
bias_name
]
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
self
.
pin_bias
=
None
self
.
pin_bias
=
None
...
@@ -548,6 +635,36 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -548,6 +635,36 @@ class MMWeightQuantTemplate(MMWeightTemplate):
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
if
self
.
weight_need_transpose
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
else
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
weight_scale_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
)
self
.
pin_weight_scale
=
self
.
pin_weight_scale
.
copy_
(
weight_scale_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
@
MM_WEIGHT_REGISTER
(
"fp8-vllm"
)
@
MM_WEIGHT_REGISTER
(
"fp8-vllm"
)
class
MMWeightWfp8channelAfp8channeldynamicVllm
(
MMWeightQuantTemplate
):
class
MMWeightWfp8channelAfp8channeldynamicVllm
(
MMWeightQuantTemplate
):
...
@@ -560,8 +677,8 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
     Kernel: vllm
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_fp8_perchannel_sym
         self.weight_need_transpose = True
         self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
...
@@ -595,8 +712,8 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
     Kernel: vllm
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = True
         self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
...
@@ -605,7 +722,7 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
-        output_tensor = torch.zeros(shape, dtype=dtype, device=device, requires_grad=False)
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
         input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
         torch.ops._C.cutlass_scaled_mm(
...
@@ -629,8 +746,8 @@ class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
     Act: mxfp4
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_mxfp4
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_mxfp4
...
@@ -656,8 +773,8 @@ class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
     Act: mxfp8
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_mxfp6
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_mxfp8
...
@@ -683,8 +800,8 @@ class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
     Act: mxfp8
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_mxfp8
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_mxfp8
...
@@ -710,8 +827,8 @@ class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
     Act: nvfp4
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_nvfp4
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_nvfp4
...
@@ -722,13 +839,13 @@ class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
         return output_tensor

     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_weight_scale"):
-            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
-            self.input_global_scale = self.pin_input_global_scale.cuda(non_blocking=non_blocking)
-            self.alpha = self.pin_alpha.cuda(non_blocking=non_blocking)
+            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
+            self.input_global_scale = self.pin_input_global_scale.to(AI_DEVICE, non_blocking=non_blocking)
+            self.alpha = self.pin_alpha.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
@@ -758,8 +875,8 @@ class MMCalibNvfp4(MMWeight):
     absmax: torch.max(torch.abs(input_tensor))
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.running_absmax = None
         self.count = 0
         self.decay = 0.9
...
@@ -794,11 +911,12 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
     Kernel: Q8F
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_fp8_perchannel_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
+        self.bias_force_fp32 = True

     def apply(self, input_tensor):
         input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
...
@@ -824,8 +942,8 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
     Kernel: Q8F
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
...
@@ -855,8 +973,8 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuant
     Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_fp8_perblock128_sym
         self.weight_need_transpose = False
         self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
...
@@ -889,8 +1007,8 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
     Kernel: Sgl-kernel
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_fp8_perchannel_sym
         self.weight_need_transpose = True
         self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
...
@@ -903,7 +1021,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
             input_tensor_scale,
             self.weight_scale,
             self.infer_dtype,
-            bias=self.bias,
+            self.bias if self.bias is not None else None,
         )
         return output_tensor
...
@@ -919,8 +1037,8 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = True
         self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
...
@@ -944,7 +1062,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 @MM_WEIGHT_REGISTER("int8-torchao")
-class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
+class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
...
@@ -954,8 +1072,8 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
     Kernel: Torchao
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_int8_perchannel_sym
         self.weight_need_transpose = True
         self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
...
@@ -971,33 +1089,34 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
 class MMWeightGGUFTemplate(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

     def load(self, weight_dict):
         assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
         if not self.lazy_load:
             self.weight = weight_dict[self.weight_name]
             weight_shape = self.weight.shape
             weight_dtype = self.weight.dtype
             if isinstance(self.weight, GGMLTensor):
                 self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
                 self.pin_weight.copy_from(self.weight)
             else:
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
             if self.bias_name is not None:
                 self.bias = weight_dict[self.bias_name]
                 if isinstance(self.bias, GGMLTensor):
                     self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
                     self.pin_bias.copy_from(self.bias)
                 else:
                     self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
             else:
                 self.bias = None

     def load_state_dict(self, destination, block_index, adapter_block_index=None):
         if self.is_post_adapter:
...
@@ -1035,9 +1154,7 @@ class MMWeightGGUFTemplate(MMWeightTemplate):
         if tensor is None:
             return
-        device = tensor.device
         weight = gguf_dequantize_tensor(tensor, dtype)
-        # prevent propagating custom tensor class
         if isinstance(weight, GGMLTensor):
             weight = torch.Tensor(weight)
...
@@ -1135,8 +1252,8 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
     Kernel: Marlin
     """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.load_func = self.load_quantized

     def load(self, weight_dict):
...
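Taken together, these matmul weight classes follow one offload pattern: tensors are staged once in pinned host memory, uploaded with a non-blocking .to(AI_DEVICE) in to_cuda, and released again in to_cpu. Below is a minimal, self-contained sketch of that round trip, assuming only that AI_DEVICE resolves to an ordinary torch device string; the class and names are invented for illustration and are not LightX2V APIs.

    import torch

    AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # assumption for the sketch

    class OffloadedWeight:
        def __init__(self, tensor):
            # Stage the weight in page-locked host memory (when CUDA is present)
            # so host-to-device copies can run asynchronously.
            self.pin_weight = torch.empty(tensor.shape, dtype=tensor.dtype,
                                          pin_memory=torch.cuda.is_available())
            self.pin_weight.copy_(tensor)
            self.weight = None

        def to_cuda(self, non_blocking=False):
            # Device-agnostic upload; .to(AI_DEVICE) plays the role of the old .cuda().
            self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)

        def to_cpu(self):
            # Drop the device copy; the pinned host buffer stays resident for reuse.
            self.weight = None

    w = OffloadedWeight(torch.randn(1024, 1024, dtype=torch.bfloat16))
    w.to_cuda(non_blocking=True)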
lightx2v/common/ops/norm/layer_norm_weight.py
View file @ 74eeb429
...
@@ -5,16 +5,18 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .triton_ops import norm_infer

 class LNWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
         self.weight_name = weight_name
         self.bias_name = bias_name
         self.eps = eps
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
...
@@ -23,53 +25,71 @@ class LNWeightTemplate(metaclass=ABCMeta):
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

-    def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                if self.weight_name is not None:
-                    self.weight_cuda_buffer = weight_dict[self.weight_name].cuda().t()
-                if self.bias_name is not None:
-                    self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
-            else:
-                device = weight_dict[self.weight_name].device
-                if device.type == "cpu":
-                    weight_shape = weight_dict[self.weight_name].shape
-                    weight_dtype = weight_dict[self.weight_name].dtype
-                    self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                    self.pin_weight.copy_(weight_dict[self.weight_name])
-                    if self.bias_name is not None:
-                        bias_shape = weight_dict[self.bias_name].shape
-                        bias_dtype = weight_dict[self.bias_name].dtype
-                        self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                        self.pin_bias.copy_(weight_dict[self.bias_name])
-                    else:
-                        self.bias = None
-                        self.pin_bias = None
-                    del weight_dict[self.weight_name]
-                else:
-                    self.weight = weight_dict[self.weight_name]
-                    if self.bias_name is not None:
-                        self.bias = weight_dict[self.bias_name]
-                    else:
-                        self.bias = None
-        else:
-            self.weight = None
-            self.bias = None
-
-    def _calculate_size(self):
-        if self.weight is None:
-            return 0
-        if self.bias is not None:
-            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
-        return self.weight.numel() * self.weight.element_size()
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
+    def load(self, weight_dict):
+        if self.create_cuda_buffer:
+            self._load_cuda_buffers(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffers()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load and self.weight_name is not None:
+            device = weight_dict[self.weight_name].device
+            if device.type == "cpu":
+                weight_tensor = weight_dict[self.weight_name]
+                self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+                bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
+                self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
+                self.bias = None
+                del weight_dict[self.weight_name]
+            else:
+                self.weight = weight_dict[self.weight_name]
+                self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+        else:
+            self.weight = None
+            self.bias = None
+
+    def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
+        if name is None:
+            return None
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[name]
+        return tensor
+
+    def _create_cpu_pin_tensor(self, tensor):
+        if tensor is None:
+            return None
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffers(self, weight_dict):
+        weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
+        if weight_tensor is not None:
+            self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+        bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
+        if bias_tensor is not None:
+            self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffers(self):
+        weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
+        if weight_tensor is not None:
+            self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+        else:
+            self.weight = None
+        bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
+        if bias_tensor is not None:
+            self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
+        else:
+            self.bias = None
+            self.pin_bias = None

     @abstractmethod
     def apply(self, input_tensor):
...
@@ -81,11 +101,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
     def to_cuda(self, non_blocking=False):
         if hasattr(self, "pin_weight") and self.pin_weight is not None:
-            self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+            self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         else:
             self.weight = None
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
         else:
             self.bias = None
...
@@ -129,28 +149,33 @@ class LNWeightTemplate(metaclass=ABCMeta):
         else:
             self.bias = None

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.weight_name is not None:
+            if self.is_post_adapter:
+                self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+            else:
+                self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+            self.pin_weight = self.pin_weight.copy_(weight_tensor)
+            del weight_tensor
+        if self.bias_name is not None:
+            if self.is_post_adapter:
+                assert adapter_block_index is not None
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+            else:
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
+            self.pin_bias.copy_(bias_tensor)
+            del bias_tensor
+
 @LN_WEIGHT_REGISTER("Default")
 class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if self.weight_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-        else:
-            self.weight = None
-        if self.bias_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
-        else:
-            self.bias = None
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

     def apply(self, input_tensor):
         if self.sensitive_layer_dtype != self.infer_dtype:
...
@@ -169,25 +194,8 @@ class LNWeight(LNWeightTemplate):
 @LN_WEIGHT_REGISTER("Triton")
 class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if self.weight_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-        else:
-            self.weight = None
-        if self.bias_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
-        else:
-            self.bias = None
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

     def apply(self, input_tensor):
         input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
...
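The reworked layer-norm template dispatches load() three ways: a persistent device buffer (create_cuda_buffer), a pinned host buffer (create_cpu_buffer), or a plain eager tensor. Below is a standalone sketch of that dispatch for a single weight, with hypothetical names and the same pin-memory semantics assumed; it is an illustration, not the library's class.

    import torch

    class NormWeight:
        def __init__(self, name, create_cuda_buffer=False, create_cpu_buffer=False, device="cpu"):
            self.name = name
            self.create_cuda_buffer = create_cuda_buffer
            self.create_cpu_buffer = create_cpu_buffer
            self.device = device  # stands in for AI_DEVICE

        def load(self, weight_dict):
            if self.create_cuda_buffer:      # persistent device-resident buffer
                self.weight_cuda_buffer = weight_dict[self.name].to(self.device)
            elif self.create_cpu_buffer:     # pinned host buffer, uploaded on demand
                src = weight_dict[self.name]
                self.pin_weight = torch.empty(src.shape, dtype=src.dtype,
                                              pin_memory=torch.cuda.is_available())
                self.pin_weight.copy_(src)
            else:                            # plain eager tensor, no offload bookkeeping
                self.weight = weight_dict[self.name]

    w = NormWeight("norm.weight", create_cpu_buffer=True)
    w.load({"norm.weight": torch.ones(4096)})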
lightx2v/common/ops/norm/rms_norm_weight.py
View file @ 74eeb429
...
@@ -5,6 +5,7 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE

 try:
     import sgl_kernel
...
@@ -13,10 +14,11 @@ except ImportError:

 class RMSWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
         self.weight_name = weight_name
         self.eps = eps
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
...
@@ -25,26 +27,45 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         self.config = {}

     def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
-            else:
-                device = weight_dict[self.weight_name].device
-                if device.type == "cpu":
-                    weight_shape = weight_dict[self.weight_name].shape
-                    weight_dtype = weight_dict[self.weight_name].dtype
-                    self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                    self.pin_weight.copy_(weight_dict[self.weight_name])
-                    del weight_dict[self.weight_name]
-                else:
-                    self.weight = weight_dict[self.weight_name]
-
-    def clear(self):
-        attrs = ["weight", "pinned_weight"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
+        if self.create_cuda_buffer:
+            self._load_cuda_buffer(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffer()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load:
+            device = weight_dict[self.weight_name].device
+            if device.type == "cpu":
+                weight_tensor = weight_dict[self.weight_name]
+                self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
+                del weight_dict[self.weight_name]
+            else:
+                self.weight = weight_dict[self.weight_name]
+
+    def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(self.weight_name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[self.weight_name]
+        return tensor
+
+    def _create_cpu_pin_weight(self, tensor):
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffer(self, weight_dict):
+        weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+        self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffer(self):
+        weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
+        self.pin_weight = self._create_cpu_pin_weight(weight_tensor)

     @abstractmethod
     def apply(self, input_tensor):
...
@@ -55,7 +76,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         self.config = config

     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
@@ -63,31 +84,6 @@ class RMSWeightTemplate(metaclass=ABCMeta):
         else:
             self.weight = self.weight.to("cpu", non_blocking=non_blocking)

-    def _calculate_size(self):
-        return self.weight.numel() * self.weight.element_size()
-
-
-@RMS_WEIGHT_REGISTER("Default")
-class RMSWeight(RMSWeightTemplate):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-    def apply(self, input_tensor):
-        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
-            input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
-        else:
-            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
-        return input_tensor
-
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
...
@@ -106,6 +102,32 @@ class RMSWeight(RMSWeightTemplate):
             return
         self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+        else:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+        self.pin_weight = self.pin_weight.copy_(weight_tensor)
+        del weight_tensor
+
+
+@RMS_WEIGHT_REGISTER("Default")
+class RMSWeight(RMSWeightTemplate):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+    def apply(self, input_tensor):
+        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
+            input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
+        else:
+            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
+        return input_tensor
+

 @RMS_WEIGHT_REGISTER("sgl-kernel")
 class RMSWeightSgl(RMSWeight):
...
@@ -113,18 +135,13 @@ class RMSWeightSgl(RMSWeight):
         self,
         weight_name,
         create_cuda_buffer=False,
+        create_cpu_buffer=False,
         lazy_load=False,
         lazy_load_file=None,
         is_post_adapter=False,
         eps=1e-6,
     ):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

     def apply(self, input_tensor):
         if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
...
@@ -146,8 +163,8 @@ class RMSWeightSgl(RMSWeight):
 @RMS_WEIGHT_REGISTER("fp32_variance")
 class RMSWeightFP32(RMSWeight):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

     def apply(self, input_tensor):
         input_dtype = input_tensor.dtype
...
@@ -165,8 +182,8 @@ class RMSWeightFP32(RMSWeight):
 @RMS_WEIGHT_REGISTER("self_forcing")
 class RMSWeightSF(RMSWeight):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

     def _norm(self, x):
         return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
...
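All of these weight types are selected through string-keyed registries (MM_WEIGHT_REGISTER, LN_WEIGHT_REGISTER, RMS_WEIGHT_REGISTER, TENSOR_REGISTER). The sketch below shows how such a registry decorator is typically wired; the real factory lives in lightx2v.utils.registry_factory and may differ in detail, so treat the names here as illustrative.

    class Register:
        def __init__(self):
            self._classes = {}

        def __call__(self, key):
            # Used as a decorator: @RMS_WEIGHT_REGISTER("Default")
            def wrapper(cls):
                self._classes[key] = cls
                return cls
            return wrapper

        def __getitem__(self, key):
            return self._classes[key]

    RMS_WEIGHT_REGISTER = Register()

    @RMS_WEIGHT_REGISTER("Default")
    class MyRMSWeight:
        def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False):
            self.weight_name = weight_name

    # A config string then picks the implementation at model-build time:
    weight = RMS_WEIGHT_REGISTER["Default"]("blocks.0.norm.weight", create_cpu_buffer=True)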
lightx2v/common/ops/tensor/tensor.py
View file @ 74eeb429
...
@@ -4,52 +4,64 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import TENSOR_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE

 @TENSOR_REGISTER("Default")
 class DefaultTensor:
-    def __init__(self, tensor_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.tensor_name = tensor_name
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.infer_dtype = GET_DTYPE()
         self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
-        else:
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
-
-    def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
-            else:
-                device = weight_dict[self.tensor_name].device
-                if device.type == "cpu":
-                    tensor_shape = weight_dict[self.tensor_name].shape
-                    tensor_dtype = weight_dict[self.tensor_name].dtype
-                    self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
-                    self.pin_tensor.copy_(weight_dict[self.tensor_name])
-                    del weight_dict[self.tensor_name]
-                else:
-                    self.tensor = weight_dict[self.tensor_name]
-
-    def clear(self):
-        attrs = ["tensor", "pinned_tensor"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
-
-    def _calculate_size(self):
-        return self.tensor.numel() * self.tensor.element_size()
+    def load(self, weight_dict):
+        if self.create_cuda_buffer:
+            self._load_cuda_buffer(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffer()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load:
+            device = weight_dict[self.tensor_name].device
+            if device.type == "cpu":
+                tensor = weight_dict[self.tensor_name]
+                self.pin_tensor = self._create_cpu_pin_tensor(tensor)
+                del weight_dict[self.tensor_name]
+            else:
+                self.tensor = weight_dict[self.tensor_name]
+
+    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(self.tensor_name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[self.tensor_name]
+        return tensor
+
+    def _create_cpu_pin_tensor(self, tensor):
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffer(self, weight_dict):
+        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffer(self):
+        tensor = self._get_tensor(use_infer_dtype=True)
+        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

     def to_cuda(self, non_blocking=False):
-        self.tensor = self.pin_tensor.cuda(non_blocking=non_blocking)
+        self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)

     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_tensor"):
...
@@ -69,8 +81,18 @@ class DefaultTensor:
             tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
         else:
             tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
         if tensor_name not in destination:
             self.tensor = None
             return
         self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
+        else:
+            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+        tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
+        self.pin_tensor = self.pin_tensor.copy_(tensor)
+        del tensor
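Every load_state_dict_from_disk variant above rewrites the first numeric segment of a stored tensor name before fetching it, so a single pre-allocated pinned buffer can serve every transformer block in turn. A small, self-contained illustration of that renaming follows; the tensor names are made up for the example.

    import re

    def rename_block(name: str, block_index: int) -> str:
        # Replace only the first ".<digits>" segment, exactly as the re.sub calls above do.
        return re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1)

    print(rename_block("blocks.0.attn.qkv.weight", 17))   # -> blocks.17.attn.qkv.weight
    print(rename_block("blocks.3.ffn.0.bias", 17))        # -> blocks.17.ffn.0.bias (later indices untouched)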
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py
100644 → 100755
View file @ 74eeb429
...
@@ -62,17 +62,13 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         self.VAE_IMAGE_SIZE = 1024 * 1024
         self.cpu_offload = config.get("cpu_offload", False)
-        if self.cpu_offload:
-            self.device = torch.device("cpu")
-        else:
-            self.device = torch.device(AI_DEVICE)
         self.dtype = torch.bfloat16
         self.load()

     def load(self):
         self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
         if not self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
         self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
         if self.config["task"] == "i2i":
...
@@ -98,7 +94,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
     @torch.no_grad()
     def infer(self, text, image_list=None):
         if self.cpu_offload:
-            self.text_encoder.to(self.device)
+            self.text_encoder.to(AI_DEVICE)
         if image_list is not None:
             condition_image_list = []
...
@@ -133,7 +129,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
                 images=condition_image_list,
                 padding=True,
                 return_tensors="pt",
-            ).to(torch.device(self.device))
+            ).to(AI_DEVICE)
             encoder_hidden_states = self.text_encoder(
                 input_ids=model_inputs.input_ids,
...
@@ -156,7 +152,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
             txt = [template.format(e) for e in text]
             image_info = {}
-            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
+            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE)
             encoder_hidden_states = self.text_encoder(
                 input_ids=model_inputs.input_ids,
                 attention_mask=model_inputs.attention_mask,
...
@@ -172,7 +168,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
         prompt_embeds = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0), u.size(1))]) for u in split_hidden_states])
         encoder_attention_mask = torch.stack([torch.cat([u, u.new_zeros(max_seq_len - u.size(0))]) for u in attn_mask_list])
-        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=self.device)
+        prompt_embeds = prompt_embeds.to(dtype=self.dtype, device=AI_DEVICE)
         prompt_embeds_mask = encoder_attention_mask
         _, seq_len, _ = prompt_embeds.shape
...
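The text-encoder changes replace every hard-coded "cuda" string and cached self.device attribute with the shared AI_DEVICE constant. A plausible minimal shape for such a module is sketched below; the actual definition in lightx2v_platform.base.global_var is not shown in this commit and may well differ, so read this only as an assumption about the pattern.

    import torch

    def _detect_device() -> str:
        if torch.cuda.is_available():
            return "cuda"
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    AI_DEVICE = _detect_device()

    # Call sites then stay backend-neutral:
    x = torch.randn(2, 8).to(AI_DEVICE)       # replaces x.cuda()
    y = torch.arange(8, device=AI_DEVICE)     # replaces device="cuda"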
lightx2v/models/input_encoders/hf/wan/t5/model.py View file @ 74eeb429
@@ -515,7 +515,7 @@ class T5Encoder(nn.Module):
            e = pos_bias
        else:
            lq, lk = x.size(1), x.size(1)
-           rel_pos = torch.arange(lk, device="cuda").unsqueeze(0) - torch.arange(lq, device="cuda").unsqueeze(1)
+           rel_pos = torch.arange(lk, device=AI_DEVICE).unsqueeze(0) - torch.arange(lq, device=AI_DEVICE).unsqueeze(1)
            num_buckets = block.pos_embedding.weight.shape[0] // 2
            rel_buckets = (rel_pos > 0).long() * num_buckets
            rel_pos = torch.abs(rel_pos)
        ...
@@ -532,28 +532,21 @@ class T5Encoder(nn.Module):
        return x

    def forward_with_offload(self, ids, mask=None):
-       self.token_embedding = self.token_embedding.to("cuda")
+       self.token_embedding = self.token_embedding.to(AI_DEVICE)
-       self.pos_embedding = self.pos_embedding.to("cuda") if self.pos_embedding is not None else None
+       self.pos_embedding = self.pos_embedding.to(AI_DEVICE) if self.pos_embedding is not None else None
        x = self.token_embedding(ids)
        x = self.dropout(x)
        e = self.pos_embedding(x.size(1), x.size(1)) if self.shared_pos else None
-       self.norm = self.norm.to("cuda")
+       self.norm = self.norm.to(AI_DEVICE)
        for block_idx in range(len(self.blocks)):
            self.block_idx = block_idx
            if block_idx == 0:
                self.offload_manager.cuda_buffers[0].load_state_dict(
                    self.blocks[block_idx].state_dict(),
                    block_idx,
                )
-           x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e)
+           if block_idx < len(self.blocks) - 1:
+               self.offload_manager.prefetch_weights(block_idx + 1, self.blocks)
+           with torch.cuda.stream(self.offload_manager.compute_stream):
+               x = self.forward_block_with_offload(self.offload_manager.cuda_buffers[0], x, mask, pos_bias=e)
+           self.offload_manager.swap_blocks()
        x = self.norm(x)
        x = self.dropout(x)
        ...
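The rewritten `forward_with_offload` follows a double-buffer pipeline: keep one block's weights resident on the accelerator, prefetch the next block's weights on a side stream, run the current block on the compute stream, then swap buffers. The toy below sketches that pipeline under stated assumptions (a CUDA device, an illustrative `ToyManager`, a matmul standing in for a transformer block); the real `WeightAsyncStreamManager` in `lightx2v.common.offload.manager` carries far more responsibility, including the phase-level and disk offload paths this commit reworks.

```python
# Self-contained toy of the prefetch / compute / swap pattern behind forward_with_offload.
import torch

class ToyManager:
    def __init__(self):
        self.load_stream = torch.cuda.Stream()     # H2D weight copies run here
        self.compute_stream = torch.cuda.Stream()  # kernels run here
        self.cuda_buffers = [None, None]           # [active, incoming]

    def init_first_buffer(self, cpu_weights):
        self.cuda_buffers[0] = cpu_weights[0].to("cuda")

    def prefetch_weights(self, idx, cpu_weights):
        with torch.cuda.stream(self.load_stream):  # copy the next block off the critical path
            self.cuda_buffers[1] = cpu_weights[idx].to("cuda", non_blocking=True)

    def swap_blocks(self):
        self.compute_stream.wait_stream(self.load_stream)  # never compute on a half-copied buffer
        self.cuda_buffers[0], self.cuda_buffers[1] = self.cuda_buffers[1], self.cuda_buffers[0]

if torch.cuda.is_available():
    cpu_weights = [torch.randn(64, 64).pin_memory() for _ in range(4)]
    mgr, x = ToyManager(), torch.randn(1, 64, device="cuda")
    mgr.init_first_buffer(cpu_weights)
    mgr.compute_stream.wait_stream(torch.cuda.current_stream())  # inputs and buffer 0 are ready
    for block_idx in range(len(cpu_weights)):
        if block_idx < len(cpu_weights) - 1:
            mgr.prefetch_weights(block_idx + 1, cpu_weights)
        with torch.cuda.stream(mgr.compute_stream):
            x = x @ mgr.cuda_buffers[0]            # "run the block" on the resident weights
        mgr.swap_blocks()
    torch.cuda.synchronize()
```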
lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py View file @ 74eeb429
@@ -6,6 +6,7 @@ import torch
import torch.nn.functional as F

from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE


class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
        ...
@@ -101,8 +102,8 @@ class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
    def infer_using_cache(self, infer_module_out):
        residual_img = self.residual_cache[self.scheduler.infer_condition]
        residual_txt = self.residual_cache_txt[self.scheduler.infer_condition]
-       infer_module_out.img.add_(residual_img.cuda())
+       infer_module_out.img.add_(residual_img.to(AI_DEVICE))
-       infer_module_out.txt.add_(residual_txt.cuda())
+       infer_module_out.txt.add_(residual_txt.to(AI_DEVICE))

    def clear(self):
        self.accumulated_err = {True: 0.0, False: 0.0}
        ...
lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py View file @ 74eeb429
@@ -2,6 +2,9 @@ import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
        ...
@@ -26,6 +29,6 @@ class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
                self.offload_manager.init_first_buffer(weights.double_blocks)
            if block_idx < self.double_blocks_num - 1:
                self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks)
-           with torch.cuda.stream(self.offload_manager.compute_stream):
+           with torch_device_module.stream(self.offload_manager.compute_stream):
                infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out)
            self.offload_manager.swap_blocks()
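The new module-level `torch_device_module = getattr(torch, AI_DEVICE)` is what lets `with torch.cuda.stream(...)` become `with torch_device_module.stream(...)`: the stream context is looked up on whichever backend submodule the platform names instead of being pinned to `torch.cuda`. A small sketch, with `AI_DEVICE` hard-coded to `"cuda"` purely for illustration (other accelerators would rely on their torch extension exposing the same `Stream`/`stream`/`synchronize` surface):

```python
import torch

AI_DEVICE = "cuda"                                 # illustration only; resolved by lightx2v_platform
torch_device_module = getattr(torch, AI_DEVICE)    # -> torch.cuda here

if torch.cuda.is_available():
    side_stream = torch_device_module.Stream()
    with torch_device_module.stream(side_stream):  # same call shape as torch.cuda.stream(...)
        y = torch.ones(4, device=AI_DEVICE) * 2
    torch_device_module.synchronize()
```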
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py View file @ 74eeb429
@@ -234,16 +234,11 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
                attention_module=weights.self_attention,
                seq_p_group=self.seq_p_group,
                use_fp8_comm=self.seq_p_fp8_comm,
-               model_cls=self.config["model_cls"],
            )
        else:
            attn_out = weights.self_attention.apply(
                q=query,
                k=key,
                v=value,
                cu_seqlens_q=cu_seqlens_qkv,
                cu_seqlens_kv=cu_seqlens_qkv,
                max_seqlen_q=seqlen,
                max_seqlen_kv=seqlen,
-               model_cls=self.config["model_cls"]
            )
        img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:]
        ...
lightx2v/models/networks/hunyuan_video/model.py View file @ 74eeb429
@@ -32,7 +32,8 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
        self.seq_p_group = None
        self.cpu_offload = self.config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
-       self.remove_keys = ["byt5_in", "vision_in"]
+       self.remove_keys = []
+       self.remove_keys.extend(["byt5_in", "vision_in"])
        self.dit_quantized = self.config.get("dit_quantized", False)
        if self.dit_quantized:
            assert self.config.get("dit_quant_scheme", "Default") in [
        ...
@@ -98,7 +99,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
-           self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+           self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
        ...
@@ -176,7 +177,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-       if self.config["parallel"]:
+       if self.device.type != "cpu" and dist.is_initialized():
            device = dist.get_rank()
        else:
            device = str(self.device)
        ...
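The loading guard changes from a config flag to a runtime check: a distributed rank is consulted only when the model is not on CPU and a process group has actually been initialized. A hedged standalone sketch of that decision (the helper name is illustrative, not a function from this repository):

```python
import torch
import torch.distributed as dist

def resolve_load_device(model_device: torch.device):
    # Mirrors the new guard: only ask torch.distributed for a rank when it is safe to.
    if model_device.type != "cpu" and dist.is_initialized():
        return dist.get_rank()      # per-rank accelerator ordinal, as in the diff
    return str(model_device)        # e.g. "cpu" or "cuda:0"

print(resolve_load_device(torch.device("cpu")))  # "cpu", no process group required
```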
lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py View file @ 74eeb429
@@ -24,7 +24,7 @@ class HunyuanVideo15TransformerWeights(WeightModule):
        if config["cpu_offload"]:
            if config.get("offload_granularity", "block") == "block":
                self.offload_blocks_num = 2
-               self.offload_block_buffers = WeightModuleList(
+               self.offload_block_cuda_buffers = WeightModuleList(
                    [
                        MMDoubleStreamBlock(
                            i,
@@ -36,8 +36,8 @@ class HunyuanVideo15TransformerWeights(WeightModule):
                        for i in range(self.offload_blocks_num)
                    ]
                )
-               self.add_module("offload_block_buffers", self.offload_block_buffers)
+               self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
-               self.offload_phase_buffers = None
+               self.offload_phase_cuda_buffers = None

    def non_block_weights_to_cuda(self):
        self.final_layer.to_cuda()
@@ -47,23 +47,24 @@ class HunyuanVideo15TransformerWeights(WeightModule):
class MMDoubleStreamBlock(WeightModule):
-   def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+   def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
        self.config = config
-       self.is_offload_buffer = is_offload_buffer
+       self.create_cuda_buffer = create_cuda_buffer
+       self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = False
        self.lazy_load_file = None
        self.add_module(
            "img_branch",
-           MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, is_offload_buffer),
+           MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
        )
        self.add_module(
            "txt_branch",
-           MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, is_offload_buffer),
+           MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
        )
        attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["attn_type"]]
        self.add_module("self_attention", attention_weights_cls())
@@ -75,7 +76,7 @@ class MMDoubleStreamBlock(WeightModule):
class MMDoubleStreamBlockImgBranch(WeightModule):
-   def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+   def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
@@ -93,7 +94,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mod.linear.weight",
                f"{block_prefix}.{self.block_index}.img_mod.linear.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -103,6 +105,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -112,7 +116,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_q.weight",
                f"{block_prefix}.{self.block_index}.img_attn_q.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -122,7 +127,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_k.weight",
                f"{block_prefix}.{self.block_index}.img_attn_k.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -132,7 +138,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_v.weight",
                f"{block_prefix}.{self.block_index}.img_attn_v.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -141,7 +148,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            "img_attn_q_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.img_attn_q_norm.weight",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -150,7 +158,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            "img_attn_k_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.img_attn_k_norm.weight",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -160,7 +169,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_proj.weight",
                f"{block_prefix}.{self.block_index}.img_attn_proj.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -170,6 +180,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -179,7 +191,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mlp.fc1.weight",
                f"{block_prefix}.{self.block_index}.img_mlp.fc1.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -189,7 +202,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mlp.fc2.weight",
                f"{block_prefix}.{self.block_index}.img_mlp.fc2.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -197,7 +211,7 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
class MMDoubleStreamBlockTxtBranch(WeightModule):
-   def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+   def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
@@ -215,7 +229,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mod.linear.weight",
                f"{block_prefix}.{self.block_index}.txt_mod.linear.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -225,6 +240,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -234,7 +251,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_q.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_q.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -244,7 +262,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_k.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_k.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -254,7 +273,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_v.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_v.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -263,7 +283,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            "txt_attn_q_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.txt_attn_q_norm.weight",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -272,7 +293,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            "txt_attn_k_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.txt_attn_k_norm.weight",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -282,7 +304,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_proj.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_proj.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -292,6 +315,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -301,7 +326,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mlp.fc1.weight",
                f"{block_prefix}.{self.block_index}.txt_mlp.fc1.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -311,7 +337,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mlp.fc2.weight",
                f"{block_prefix}.{self.block_index}.txt_mlp.fc2.bias",
-               is_offload_buffer,
+               create_cuda_buffer,
+               create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -333,6 +360,8 @@ class FinalLayerWeights(WeightModule):
            MM_WEIGHT_REGISTER["Default"](
                "final_layer.adaLN_modulation.1.weight",
                "final_layer.adaLN_modulation.1.bias",
+               False,
+               False,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -342,6 +371,8 @@ class FinalLayerWeights(WeightModule):
            MM_WEIGHT_REGISTER["Default"](
                "final_layer.linear.weight",
                "final_layer.linear.bias",
+               False,
+               False,
                self.lazy_load,
                self.lazy_load_file,
            ),
@@ -351,6 +382,8 @@ class FinalLayerWeights(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+               False,
+               False,
                self.lazy_load,
                self.lazy_load_file,
            ),
        ...
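Throughout this file the single `is_offload_buffer` flag is split into `create_cuda_buffer` and `create_cpu_buffer`, and both are threaded into every MM/LN/RMS weight constructor. The sketch below shows one plausible way such flags could drive buffer allocation; it is illustrative only, and the real weight classes also handle quantization, lazy loading from safetensors files, and more.

```python
import torch

class ToyWeight:
    # Illustrative weight holder; not the lightx2v MM/LN/RMS implementation.
    def __init__(self, weight_name, shape, create_cuda_buffer=False, create_cpu_buffer=False):
        self.weight_name = weight_name
        if create_cuda_buffer:
            # Persistent device-side buffer that offloaded blocks copy their weights into.
            self.tensor = torch.empty(shape, device="cuda" if torch.cuda.is_available() else "cpu")
        elif create_cpu_buffer:
            # Host-side staging buffer; pinned when possible so H2D copies can be asynchronous.
            staging = torch.empty(shape)
            self.tensor = staging.pin_memory() if torch.cuda.is_available() else staging
        else:
            self.tensor = None  # filled directly by load()

    def load(self, weight_dict):
        if self.tensor is None:
            self.tensor = weight_dict[self.weight_name]
        else:
            self.tensor.copy_(weight_dict[self.weight_name])

w = ToyWeight("double_blocks.0.img_attn_q.weight", (8, 8), create_cpu_buffer=True)
w.load({"double_blocks.0.img_attn_q.weight": torch.randn(8, 8)})
```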
lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py View file @ 74eeb429
@@ -2,6 +2,9 @@ import torch
from lightx2v.common.offload.manager import WeightAsyncStreamManager
from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
        ...
@@ -37,7 +40,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
            if block_idx < self.num_blocks - 1:
                self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
-           with torch.cuda.stream(self.offload_manager.compute_stream):
+           with torch_device_module.stream(self.offload_manager.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_block(
                    block_weight=self.offload_manager.cuda_buffers[0], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
                )
        ...
lightx2v/models/networks/qwen_image/model.py View file @ 74eeb429
@@ -8,7 +8,6 @@ from safetensors import safe_open
from lightx2v.utils.envs import *
from lightx2v.utils.utils import *
-from lightx2v_platform.base.global_var import AI_DEVICE
from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
from .infer.post_infer import QwenImagePostInfer
        ...
@@ -125,7 +124,7 @@ class QwenImageTransformerModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-       if self.config["parallel"]:
+       if self.device.type != "cpu" and dist.is_initialized():
            device = dist.get_rank()
        else:
            device = str(self.device)
        ...
@@ -284,7 +283,7 @@ class QwenImageTransformerModel:
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
-           self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+           self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)

    def to_cpu(self):
        self.pre_weight.to_cpu()
        ...