xuwx1 / LightX2V / Commits / 74eeb429

Unverified commit 74eeb429, authored Dec 03, 2025 by Gu Shiqiao, committed by GitHub on Dec 03, 2025.

reconstruct disk offload and fix lightx2v_platform bugs (#558)

Co-authored-by: helloyongyang <yongyang1030@163.com>

parent: f7cdbcb5
Changes: 46 — showing 20 changed files with 671 additions and 826 deletions (+671 -826).
lightx2v/common/modules/weight_module.py  +8 -29
lightx2v/common/offload/manager.py  +46 -362
lightx2v/common/ops/attn/flash_attn.py  +4 -3
lightx2v/common/ops/attn/template.py  +3 -0
lightx2v/common/ops/conv/conv2d.py  +5 -11
lightx2v/common/ops/conv/conv3d.py  +3 -9
lightx2v/common/ops/embedding/embedding_weight.py  +5 -3
lightx2v/common/ops/mm/mm_weight.py  +287 -170
lightx2v/common/ops/norm/layer_norm_weight.py  +93 -85
lightx2v/common/ops/norm/rms_norm_weight.py  +73 -56
lightx2v/common/ops/tensor/tensor.py  +51 -29
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py  +5 -9
lightx2v/models/input_encoders/hf/wan/t5/model.py  +9 -16
lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py  +3 -2
lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py  +4 -1
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py  +2 -7
lightx2v/models/networks/hunyuan_video/model.py  +4 -3
lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py  +60 -27
lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py  +4 -1
lightx2v/models/networks/qwen_image/model.py  +2 -3
lightx2v/common/modules/weight_module.py — view file @ 74eeb429

...
@@ -23,35 +23,6 @@ class WeightModule:
             if hasattr(parameter, "load"):
                 parameter.load(weight_dict)
 
-    def calculate_size(self):
-        total_size = 0
-        for _, module in self._modules.items():
-            if hasattr(module, "_calculate_size"):
-                total_size += module._calculate_size()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "_calculate_size"):
-                total_size += parameter._calculate_size()
-        return total_size
-
-    def load_from_disk(self):
-        for _, module in self._modules.items():
-            if hasattr(module, "load_from_disk"):
-                module.load_from_disk()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "load_from_disk"):
-                parameter.load_from_disk()
-
-    def clear(self):
-        for _, module in self._modules.items():
-            if hasattr(module, "clear"):
-                module.clear()
-        for _, parameter in self._parameters.items():
-            if hasattr(parameter, "clear"):
-                parameter.clear()
-
     def state_dict(self, destination=None):
         if destination is None:
             destination = {}
...
@@ -74,6 +45,14 @@ class WeightModule:
                 module.load_state_dict(destination, block_index, adapter_block_index)
         return destination
 
+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        for _, param in self._parameters.items():
+            if param is not None:
+                param.load_state_dict_from_disk(block_index, adapter_block_index)
+        for _, module in self._modules.items():
+            if module is not None:
+                module.load_state_dict_from_disk(block_index, adapter_block_index)
+
     def named_parameters(self, prefix=""):
         for name, param in self._parameters.items():
             if param is not None:
...
...
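A note on the new method above: `load_state_dict_from_disk` simply fans out over a container's `_parameters` and `_modules`, so a whole transformer block can be refilled from disk by one call at the top of the tree. A minimal sketch of that dispatch shape, using hypothetical `Param`/`Container` stand-ins rather than actual LightX2V classes:

    # Sketch only: Param and Container are made-up stand-ins for the WeightModule pattern.
    class Param:
        def __init__(self, name):
            self.name = name

        def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
            # A real implementation would read this block's tensor from a lazily
            # opened weight file and copy it into a preallocated pinned buffer.
            print(f"reload {self.name} for block {block_index}")

    class Container:
        def __init__(self, **children):
            self._parameters = {k: v for k, v in children.items() if isinstance(v, Param)}
            self._modules = {k: v for k, v in children.items() if isinstance(v, Container)}

        def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
            for _, param in self._parameters.items():
                if param is not None:
                    param.load_state_dict_from_disk(block_index, adapter_block_index)
            for _, module in self._modules.items():
                if module is not None:
                    module.load_state_dict_from_disk(block_index, adapter_block_index)

    block = Container(attn=Container(q=Param("attn.q.weight")), ffn=Param("ffn.weight"))
    block.load_state_dict_from_disk(block_index=3)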
lightx2v/common/offload/manager.py — view file @ 74eeb429

 import gc
 import queue
 import threading
 import time
 from collections import OrderedDict
 
 import torch
 from loguru import logger
 from packaging.version import parse
 
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)
 
 
 class WeightAsyncStreamManager(object):
     def __init__(self, offload_granularity):
         self.offload_granularity = offload_granularity
-        self.init_stream = torch.cuda.Stream(priority=0)
+        self.init_stream = torch_device_module.Stream(priority=0)
         self.need_init_first_buffer = True
         torch_version = parse(torch.__version__.split("+")[0])
-        if torch_version >= parse("2.7"):
-            self.cuda_load_stream = torch.cuda.Stream(priority=1)
-            self.compute_stream = torch.cuda.Stream(priority=1)
+        if AI_DEVICE == "cuda" and torch_version >= parse("2.7"):
+            self.cuda_load_stream = torch_device_module.Stream(priority=1)
+            self.compute_stream = torch_device_module.Stream(priority=1)
         else:
-            self.cuda_load_stream = torch.cuda.Stream(priority=0)
-            self.compute_stream = torch.cuda.Stream(priority=-1)
+            self.cuda_load_stream = torch_device_module.Stream(priority=0)
+            self.compute_stream = torch_device_module.Stream(priority=-1)
 
     def init_cpu_buffer(self, blocks_cpu_buffer=None, phases_cpu_buffer=None):
         self.need_init_first_buffer = True
         if self.offload_granularity == "block":
             assert blocks_cpu_buffer is not None
             self.cpu_buffers = [blocks_cpu_buffer[i] for i in range(len(blocks_cpu_buffer))]
         elif self.offload_granularity == "phase":
             assert phases_cpu_buffer is not None
             self.cpu_buffers = [phases_cpu_buffer[i] for i in range(len(phases_cpu_buffer))]
         else:
             raise NotImplementedError
 
     def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
         self.need_init_first_buffer = True
         if self.offload_granularity == "block":
             assert blocks_cuda_buffer is not None
             self.cuda_buffers = [blocks_cuda_buffer[i] for i in range(len(blocks_cuda_buffer))]
...
...
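The `torch_device_module = getattr(torch, AI_DEVICE)` indirection above is what lets the same manager create streams on whatever backend `lightx2v_platform` reports. A minimal illustration of the idea, assuming a CUDA machine (on another backend `AI_DEVICE` would name the matching `torch` submodule instead of "cuda"):

    import torch

    AI_DEVICE = "cuda"  # assumed here; in LightX2V this comes from lightx2v_platform.base.global_var
    torch_device_module = getattr(torch, AI_DEVICE)

    # Stream creation and the stream context manager are reached through the
    # indirection, so the calling code never hard-codes the backend name.
    load_stream = torch_device_module.Stream(priority=0)
    with torch_device_module.stream(load_stream):
        x = torch.ones(4, device=AI_DEVICE)
    load_stream.synchronize()
    print(x.device)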
@@ -32,17 +42,32 @@ class WeightAsyncStreamManager(object):
             raise NotImplementedError
 
     def init_first_buffer(self, blocks, adapter_block_idx=None):
-        if self.offload_granularity == "block":
-            with torch.cuda.stream(self.init_stream):
-                self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
-        else:
-            with torch.cuda.stream(self.init_stream):
-                self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
+        with torch_device_module.stream(self.init_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cuda_buffers[0].load_state_dict(self.cpu_buffers[0].state_dict(), 0, adapter_block_idx)
+            else:
+                if self.offload_granularity == "block":
+                    self.cuda_buffers[0].load_state_dict(blocks[0].state_dict(), 0, adapter_block_idx)
+                else:
+                    self.cuda_buffers[0].load_state_dict(blocks[0].compute_phases[0].state_dict(), 0, adapter_block_idx)
         self.init_stream.synchronize()
         self.need_init_first_buffer = False
 
     def prefetch_weights(self, block_idx, blocks, adapter_block_idx=None):
-        with torch.cuda.stream(self.cuda_load_stream):
-            self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
+        with torch_device_module.stream(self.cuda_load_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cpu_buffers[1].load_state_dict_from_disk(block_idx, adapter_block_idx)
+                self.cuda_buffers[1].load_state_dict(self.cpu_buffers[1].state_dict(), block_idx, adapter_block_idx)
+            else:
+                self.cuda_buffers[1].load_state_dict(blocks[block_idx].state_dict(), block_idx, adapter_block_idx)
+
+    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
+        with torch_device_module.stream(self.cuda_load_stream):
+            if hasattr(self, "cpu_buffers"):
+                self.cpu_buffers[phase_idx].load_state_dict_from_disk(block_idx, adapter_block_idx)
+                self.cuda_buffers[phase_idx].load_state_dict(self.cpu_buffers[phase_idx].state_dict(), block_idx, adapter_block_idx)
+            else:
+                self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
 
     def swap_blocks(self):
         self.cuda_load_stream.synchronize()
...
...
@@ -52,347 +77,6 @@ class WeightAsyncStreamManager(object):
             self.cuda_buffers[0],
         )
 
-    def prefetch_phase(self, block_idx, phase_idx, blocks, adapter_block_idx=None):
-        with torch.cuda.stream(self.cuda_load_stream):
-            self.cuda_buffers[phase_idx].load_state_dict(blocks[block_idx].compute_phases[phase_idx].state_dict(), block_idx, adapter_block_idx)
-
     def swap_phases(self):
         self.cuda_load_stream.synchronize()
         self.compute_stream.synchronize()
 
-
-class LazyWeightAsyncStreamManager(WeightAsyncStreamManager):
-    def __init__(
-        self,
-        blocks_num,
-        offload_ratio=1,
-        phases_num=1,
-        num_disk_workers=1,
-        max_memory=2,
-        offload_gra="phase",
-    ):
-        super().__init__(blocks_num, offload_ratio, phases_num)
-        self.offload_gra = offload_gra
-        self.worker_stop_event = threading.Event()
-        self.pin_memory_buffer = MemoryBuffer(max_memory * (1024**3))
-        self.disk_task_queue = queue.PriorityQueue()
-        self.disk_workers = []
-        self.release_workers = []
-        self._start_disk_workers(num_disk_workers)
-        self.initial_prefetch_done = False
-        self.pending_tasks = {}
-        self.task_lock = threading.Lock()
-        self.last_used_time = {}
-        self.time_lock = threading.Lock()
-
-    def _start_disk_workers(self, num_workers):
-        for i in range(num_workers):
-            if self.offload_gra == "phase":
-                worker = threading.Thread(target=self._disk_worker_loop, daemon=True)
-            else:
-                worker = threading.Thread(target=self._disk_worker_loop_block, daemon=True)
-            worker.start()
-            self.disk_workers.append(worker)
-
-    def _disk_worker_loop(self):
-        while not self.worker_stop_event.is_set():
-            try:
-                _, task = self.disk_task_queue.get(timeout=0.5)
-                if task is None:
-                    break
-                block_idx, phase_idx, phase = task
-                phase.load_from_disk()
-                self.pin_memory_buffer.push((block_idx, phase_idx), phase)
-                with self.task_lock:
-                    if (block_idx, phase_idx) in self.pending_tasks:
-                        del self.pending_tasks[(block_idx, phase_idx)]
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Disk worker thread error: {e}")
-
-    def _disk_worker_loop_block(self):
-        while not self.worker_stop_event.is_set():
-            try:
-                _, task = self.disk_task_queue.get(timeout=0.5)
-                if task is None:
-                    break
-                block_idx, block = task
-                for phase in block.compute_phases:
-                    phase.load_from_disk()
-                self.pin_memory_buffer.push(block_idx, block)
-                with self.task_lock:
-                    if block_idx in self.pending_tasks:
-                        del self.pending_tasks[block_idx]
-            except queue.Empty:
-                continue
-            except Exception as e:
-                logger.error(f"Disk worker thread error: {e}")
-
-    def _async_prefetch_block(self, blocks, next_block_idx=None):
-        if next_block_idx is None:
-            next_block_idx = self.pin_memory_buffer.get_max_block_index()
-            if next_block_idx < 0:
-                next_block_idx = 0
-        if next_block_idx == self.blocks_num:
-            return
-        if self.offload_gra == "phase":
-            for phase_idx in range(self.phases_num):
-                obj_key = (next_block_idx, phase_idx)
-                if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
-                    continue
-                with self.task_lock:
-                    self.pending_tasks[obj_key] = True
-                phase = blocks[next_block_idx].compute_phases[phase_idx]
-                priority_key = (next_block_idx, phase_idx)
-                self.disk_task_queue.put((priority_key, (next_block_idx, phase_idx, phase)))
-        else:
-            obj_key = next_block_idx
-            if self.pin_memory_buffer.exists(obj_key) or (obj_key in self.pending_tasks):
-                return
-            with self.task_lock:
-                self.pending_tasks[obj_key] = True
-            block = blocks[next_block_idx]
-            self.disk_task_queue.put((obj_key, (next_block_idx, block)))
-
-    def _sync_prefetch_block(self, blocks):
-        block_idx = 0
-        while not self.pin_memory_buffer.is_nearly_full():
-            if self.offload_gra == "phase":
-                for phase_idx in range(self.phases_num):
-                    phase = blocks[block_idx].compute_phases[phase_idx]
-                    logger.info(f"Synchronous loading: block={block_idx}, phase={phase_idx}")
-                    phase.load_from_disk()
-                    self.pin_memory_buffer.push((block_idx, phase_idx), phase)
-            else:
-                block = blocks[block_idx]
-                logger.info(f"Synchronous loading: block={block_idx}")
-                for phase in block.compute_phases:
-                    phase.load_from_disk()
-                self.pin_memory_buffer.push(block_idx, block)
-            block_idx += 1
-            if block_idx == self.blocks_num:
-                break
-
-    def prefetch_weights_from_disk(self, blocks):
-        if self.initial_prefetch_done:
-            return
-        self._sync_prefetch_block(blocks)
-        self.initial_prefetch_done = True
-
-    def prefetch_weights(self, block_idx, blocks):
-        obj_key = block_idx
-        if not self.pin_memory_buffer.exists(obj_key):
-            is_loading = False
-            with self.task_lock:
-                if obj_key in self.pending_tasks:
-                    is_loading = True
-            if is_loading:
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists(obj_key):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}")
-            else:
-                logger.info("Not find prefetch block={block_idx} task.")
-                logger.info("Sync prefetch block={block_idx}.")
-                self._async_prefetch_block(blocks, block_idx)
-                start_time = time.time()
-                for phase_idx in self.phases_num:
-                    while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
-                        time.sleep(0.001)
-                        if time.time() - start_time > 15:
-                            raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-        with torch.cuda.stream(self.cuda_load_stream):
-            block = self.pin_memory_buffer.get(obj_key)
-            block.to_cuda_async()
-            self.cuda_buffers[2] = (obj_key, block)
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx < self.offload_blocks_num:
-                if self.cuda_buffers[1] is not None:
-                    old_key, old_block = self.cuda_buffers[1]
-                    if self.pin_memory_buffer.exists(old_key):
-                        old_block.to_cpu_async()
-                    self.pin_memory_buffer.pop(old_key)
-
-    def prefetch_phase(self, block_idx, phase_idx, blocks):
-        obj_key = (block_idx, phase_idx)
-        if not self.pin_memory_buffer.exists(obj_key):
-            is_loading = False
-            with self.task_lock:
-                if obj_key in self.pending_tasks:
-                    is_loading = True
-            if is_loading:
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists(obj_key):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-            else:
-                logger.info(f"Not find block={block_idx}, phase={phase_idx} task.")
-                logger.info(f"Sync prefetch block={block_idx}, phase={phase_idx}.")
-                self._async_prefetch_block(blocks, block_idx)
-                start_time = time.time()
-                while not self.pin_memory_buffer.exists((block_idx, phase_idx)):
-                    time.sleep(0.001)
-                    if time.time() - start_time > 5:
-                        raise TimeoutError(f"Load timeout: block={block_idx}, phase={phase_idx}")
-        with torch.cuda.stream(self.cuda_load_stream):
-            phase = self.pin_memory_buffer.get(obj_key)
-            phase.to_cuda_async()
-            self.cuda_buffers[2] = (obj_key, phase)
-        with torch.cuda.stream(self.cpu_load_stream):
-            if block_idx * self.phases_num + phase_idx < self.offload_phases_num:
-                if self.cuda_buffers[1] is not None:
-                    old_key, old_phase = self.cuda_buffers[1]
-                    if self.pin_memory_buffer.exists(old_key):
-                        old_phase.to_cpu_async()
-                    self.pin_memory_buffer.pop(old_key)
-
-    def shutdown(self):
-        self.worker_stop_event.set()
-        while not self.disk_task_queue.empty():
-            try:
-                self.disk_task_queue.get_nowait()
-            except queue.Empty:
-                continue
-        for _ in self.disk_workers:
-            self.disk_task_queue.put((0, None))
-        for worker in self.disk_workers:
-            worker.join(timeout=5)
-        for worker in self.release_workers:
-            worker.join(timeout=5)
-        logger.info("All worker threads have been closed")
-
-    def clear(self):
-        self.pin_memory_buffer.clear()
-        self.shutdown()
-
-
-class MemoryBuffer:
-    def __init__(self, max_memory_bytes=8 * (1024**3)):
-        self.cache = OrderedDict()
-        self.max_mem = max_memory_bytes
-        self.used_mem = 0
-        self.obj_size_map = {}
-        self.lock = threading.Lock()
-        self.insertion_order = []
-        self.insertion_index = 0
-
-    def push(self, key, obj):
-        with self.lock:
-            if key in self.cache:
-                return
-            if hasattr(obj, "compute_phases"):
-                obj_idx = key
-                if len(self.obj_size_map) == 0:
-                    _size = 0
-                    for phase in obj.compute_phases:
-                        _size += phase.calculate_size()
-                    self.obj_size_map[0] = _size
-                size = self.obj_size_map[0]
-            else:
-                _, obj_idx = key
-                if obj_idx not in self.obj_size_map:
-                    self.obj_size_map[obj_idx] = obj.calculate_size()
-                size = self.obj_size_map[obj_idx]
-            self.cache[key] = (size, obj, self.insertion_index)
-            self.insertion_order.append((key, self.insertion_index))
-            self.insertion_index += 1
-            self.used_mem += size
-
-    def _remove_key(self, key):
-        if key in self.cache:
-            size, obj, idx = self.cache.pop(key)
-            try:
-                if hasattr(obj, "compute_phases"):
-                    for phase in obj.compute_phases:
-                        phase.clear()
-                else:
-                    obj.clear()
-            except Exception as e:
-                logger.info(f"Error clearing obj: {e}")
-            self.used_mem -= size
-            self.insertion_order = [(k, i) for (k, i) in self.insertion_order if k != key]
-
-    def get(self, key, default=None):
-        with self.lock:
-            if key in self.cache:
-                size, obj, idx = self.cache[key]
-                return obj
-            return default
-
-    def exists(self, key):
-        with self.lock:
-            return key in self.cache
-
-    def pop_front(self):
-        with self.lock:
-            if not self.insertion_order:
-                return False
-            front_key, _ = self.insertion_order[0]
-            self._remove_key(front_key)
-            return True
-
-    def pop(self, key):
-        with self.lock:
-            if key in self.cache:
-                self._remove_key(key)
-                return True
-            return False
-
-    def is_nearly_full(self):
-        with self.lock:
-            return self.used_mem >= self.max_mem * 0.9
-
-    def get_max_block_index(self):
-        with self.lock:
-            if not self.cache:
-                return -1
-            if isinstance(list(self.cache.keys())[-1], tuple):
-                return (list(self.cache.keys())[-1][0] + 1) % 40
-            else:
-                return (list(self.cache.keys())[-1] + 1) % 40
-
-    def clear(self):
-        with self.lock:
-            for key in list(self.cache.keys()):
-                self._remove_key(key)
-            self.insertion_order = []
-            self.insertion_index = 0
-            self.used_mem = 0
-        torch.cuda.empty_cache()
-        gc.collect()
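For orientation, the surviving manager is a double-buffering scheme: while the compute stream runs block i out of one device buffer, the load stream fills the other buffer with block i+1 (from pinned CPU buffers that are themselves refilled from disk when cpu_buffers is configured), and the buffers are swapped between blocks. A rough, hedged sketch of a driving loop; `manager`, `blocks` and `run_block` are assumed placeholder names, not the actual LightX2V inference code:

    # Hypothetical driver for a WeightAsyncStreamManager-style double buffer.
    def run_all_blocks(manager, blocks, hidden_states):
        manager.init_first_buffer(blocks)                        # stage block 0 into cuda_buffers[0]
        for block_idx in range(len(blocks)):
            if block_idx + 1 < len(blocks):
                manager.prefetch_weights(block_idx + 1, blocks)  # load stream fills cuda_buffers[1]
            hidden_states = run_block(manager.cuda_buffers[0], hidden_states)  # compute uses cuda_buffers[0]
            manager.swap_blocks()                                # sync load stream, swap buffers 0 and 1
        return hidden_states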
lightx2v/common/ops/attn/flash_attn.py — view file @ 74eeb429

...
@@ -73,9 +73,10 @@ class FlashAttn3Weight(AttnWeightTemplate):
             bs = 1
         elif len(q.shape) == 4:
             bs = q.shape[0]
-            q = q.reshape(-1, q.shape[-2], q.shape[-1])
-            k = k.reshape(-1, k.shape[-2], k.shape[-1])
-            v = v.reshape(-1, v.shape[-2], v.shape[-1])
+        if model_cls is not None and model_cls in ["hunyuan_video_1.5"]:
+            q = q.reshape(-1, q.shape[-2], q.shape[-1])
+            k = k.reshape(-1, k.shape[-2], k.shape[-1])
+            v = v.reshape(-1, v.shape[-2], v.shape[-1])
         x = flash_attn_varlen_func_v3(
             q,
             k,
...
...
lightx2v/common/ops/attn/template.py — view file @ 74eeb429

...
@@ -30,3 +30,6 @@ class AttnWeightTemplate(metaclass=ABCMeta):
     def load_state_dict(self, destination, block_index, adapter_block_inde=None):
         return {}
+
+    def load_state_dict_from_disk(self, block_index, adapter_block_inde=None):
+        pass
lightx2v/common/ops/conv/conv2d.py — view file @ 74eeb429

...
@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
 import torch
 
 from lightx2v.utils.registry_factory import CONV2D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 
 
 class Conv2dWeightTemplate(metaclass=ABCMeta):
...
@@ -34,8 +35,8 @@ class Conv2dWeight(Conv2dWeightTemplate):
         super().__init__(weight_name, bias_name, stride, padding, dilation, groups)
 
     def load(self, weight_dict):
-        self.weight = weight_dict[self.weight_name].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+        self.weight = weight_dict[self.weight_name].to(AI_DEVICE)
+        self.bias = weight_dict[self.bias_name].to(AI_DEVICE) if self.bias_name is not None else None
 
     def apply(self, input_tensor):
         input_tensor = torch.nn.functional.conv2d(input_tensor, weight=self.weight, bias=self.bias, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=self.groups)
...
@@ -47,9 +48,9 @@ class Conv2dWeight(Conv2dWeightTemplate):
             self.bias = self.bias.cpu(non_blocking=non_blocking)
 
     def to_cuda(self, non_blocking=False):
-        self.weight = self.weight.cuda(non_blocking=non_blocking)
+        self.weight = self.weight.to(AI_DEVICE, non_blocking=non_blocking)
         if self.bias is not None:
-            self.bias = self.bias.cuda(non_blocking=non_blocking)
+            self.bias = self.bias.to(AI_DEVICE, non_blocking=non_blocking)
 
     def state_dict(self, destination=None):
         if destination is None:
...
@@ -58,10 +59,3 @@ class Conv2dWeight(Conv2dWeightTemplate):
         if self.bias is not None:
             destination[self.bias_name] = self.bias.cpu().detach().clone()
         return destination
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
lightx2v/common/ops/conv/conv3d.py — view file @ 74eeb429

...
@@ -3,6 +3,7 @@ from abc import ABCMeta, abstractmethod
 import torch
 
 from lightx2v.utils.registry_factory import CONV3D_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 
 
 class Conv3dWeightTemplate(metaclass=ABCMeta):
...
@@ -70,9 +71,9 @@ class Conv3dWeight(Conv3dWeightTemplate):
         return input_tensor
 
     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
 
     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
@@ -91,10 +92,3 @@ class Conv3dWeight(Conv3dWeightTemplate):
         if self.bias_name is not None:
             destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias  # .cpu().detach().clone()
         return destination
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
lightx2v/common/ops/embedding/embedding_weight.py — view file @ 74eeb429

...
@@ -5,12 +5,14 @@ import torch
 import torch.nn.functional as F
 
 from lightx2v.utils.registry_factory import EMBEDDING_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 
 
 class EmbeddingWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.weight_name = weight_name
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
...
@@ -19,7 +21,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
     def load(self, weight_dict):
         if not self.lazy_load:
             if self.create_cuda_buffer:
-                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
+                self.weight_cuda_buffer = weight_dict[self.weight_name].to(AI_DEVICE)
             else:
                 device = weight_dict[self.weight_name].device
                 if device.type == "cpu":
...
@@ -32,7 +34,7 @@ class EmbeddingWeightTemplate(metaclass=ABCMeta):
             self.weight = weight_dict[self.weight_name]
 
     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
 
     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
...
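Across the weight classes in this commit the constructors gain a `create_cpu_buffer` flag alongside the existing `create_cuda_buffer`, choosing where the reusable staging buffer lives. A hedged, self-contained illustration of what the three load paths amount to; `WeightLike` is a toy stand-in, not a LightX2V class, and the branch bodies are simplified:

    import torch

    class WeightLike:
        def __init__(self, create_cuda_buffer=False, create_cpu_buffer=False):
            self.create_cuda_buffer = create_cuda_buffer
            self.create_cpu_buffer = create_cpu_buffer

        def load(self, tensor):
            if self.create_cuda_buffer:
                # persistent device-side buffer, reused for every block
                self.cuda_buffer = tensor.to("cuda" if torch.cuda.is_available() else "cpu")
            elif self.create_cpu_buffer:
                # persistent pinned host buffer, refilled from disk and copied to device on demand
                self.pin_weight = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=torch.cuda.is_available()).copy_(tensor)
            else:
                # plain eager load
                self.weight = tensor

    w = WeightLike(create_cpu_buffer=True)
    w.load(torch.randn(4, 4))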
lightx2v/common/ops/mm/mm_weight.py — view file @ 74eeb429

...
@@ -9,6 +9,7 @@ from lightx2v.utils.ggml_tensor import dequantize_tensor as gguf_dequantize_tens
 from lightx2v.utils.global_paras import CALIB
 from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 
 try:
     from lightx2v_kernel.gemm import (
...
@@ -69,10 +70,11 @@ except ImportError:
 class MMWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
         self.weight_name = weight_name
         self.bias_name = bias_name
         self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.is_post_adapter = is_post_adapter
...
@@ -90,11 +92,11 @@ class MMWeightTemplate(metaclass=ABCMeta):
         self.config = config
 
     def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_weight_scale"):
-            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
+            self.weight_scale = self.pin_weight_scale.to(AI_DEVICE, non_blocking=non_blocking)
         if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
 
     def to_cpu(self, non_blocking=False):
         if hasattr(self, "pin_weight"):
...
@@ -113,44 +115,63 @@ class MMWeightTemplate(metaclass=ABCMeta):
 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
 
     def load(self, weight_dict):
         if self.create_cuda_buffer:
-            self.weight_cuda_buffer = weight_dict[self.weight_name].t().cuda()
-            if self.bias_name is not None:
-                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+            self._load_cuda_buffers(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffers()
         else:
+            self._load_default_tensors(weight_dict)
+
+    def _get_source_tensor(self, source_name, weight_dict=None):
+        if self.lazy_load:
+            return self.lazy_load_file.get_tensor(source_name)
+        return weight_dict[source_name]
+
+    def _create_pin_tensor(self, tensor, transpose=False):
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor = pin_tensor.copy_(tensor)
+        if transpose:
+            pin_tensor = pin_tensor.t()
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffers(self, weight_dict):
+        self.weight_cuda_buffer = self._get_source_tensor(self.weight_name, weight_dict).t().to(AI_DEVICE)
+        if self.bias_name is not None:
+            self.bias_cuda_buffer = self._get_source_tensor(self.bias_name, weight_dict).to(AI_DEVICE)
+
+    def _load_cpu_pin_buffers(self):
+        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name)
+        self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
+        if self.bias_name is not None:
+            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
+            self.pin_bias = self._create_pin_tensor(bias_tensor)
+        else:
+            self.bias = None
+            self.pin_bias = None
+
+    def _load_default_tensors(self, weight_dict):
         if not self.lazy_load:
             device = weight_dict[self.weight_name].device
             if device.type == "cpu":
-                weight_shape = weight_dict[self.weight_name].shape
-                weight_dtype = weight_dict[self.weight_name].dtype
-                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                self.pin_weight = self.pin_weight.copy_(weight_dict[self.weight_name]).t()
+                weight_tensor = weight_dict[self.weight_name]
+                self.pin_weight = self._create_pin_tensor(weight_tensor, transpose=True)
                 if self.bias_name is not None:
-                    bias_shape = weight_dict[self.bias_name].shape
-                    bias_dtype = weight_dict[self.bias_name].dtype
-                    self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                    self.pin_bias.copy_(weight_dict[self.bias_name])
+                    bias_tensor = weight_dict[self.bias_name]
+                    self.pin_bias = self._create_pin_tensor(bias_tensor)
                 else:
                     self.bias = None
                     self.pin_bias = None
                 del weight_dict[self.weight_name]
             else:
                 self.weight = weight_dict[self.weight_name].t()
-                if self.bias_name is not None:
-                    self.bias = weight_dict[self.bias_name]
-                else:
-                    self.bias = None
+                self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+
+    def _calculate_size(self):
+        if self.bias is not None:
+            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
+        return self.weight.numel() * self.weight.element_size()
 
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
...
...
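The `_create_pin_tensor` helper above is the staging step that makes the later `.to(AI_DEVICE, non_blocking=True)` copies in `to_cuda` genuinely asynchronous: page-locked host memory can be DMA-copied while the compute stream keeps running. A small, hedged illustration of the same idea outside LightX2V, guarded so it also runs on CPU-only machines:

    import torch

    def stage_pinned(tensor):
        # Copy a CPU tensor into page-locked (pinned) memory so a later device
        # copy with non_blocking=True can overlap with computation.
        pinned = torch.empty(tensor.shape, dtype=tensor.dtype, pin_memory=torch.cuda.is_available())
        pinned.copy_(tensor)
        return pinned

    weight = torch.randn(1024, 1024)
    pinned = stage_pinned(weight)
    if torch.cuda.is_available():
        device_weight = pinned.to("cuda", non_blocking=True)  # async H2D copy from pinned memory
        torch.cuda.current_stream().synchronize()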
@@ -169,6 +190,28 @@ class MMWeight(MMWeightTemplate):
             destination[self.bias_name] = self.pin_bias if hasattr(self, "pin_bias") else self.bias
         return destination
 
+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+        else:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).t()
+        self.pin_weight = self.pin_weight.copy_(weight_tensor)
+        del weight_tensor
+        if self.bias_name is not None:
+            if self.is_post_adapter:
+                assert adapter_block_index is not None
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+            else:
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name)
+            self.pin_bias.copy_(bias_tensor)
+            del bias_tensor
+
     def load_state_dict(self, destination, block_index, adapter_block_index=None):
         if self.is_post_adapter:
             assert adapter_block_index is not None
...
...
@@ -195,19 +238,20 @@ class MMWeight(MMWeightTemplate):
 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
 
     def load(self, weight_dict):
-        super().load(weight_dict)
-        self.weight = self.weight.to(torch.float32)
-        if hasattr(self, "bias") and self.bias is not None:
-            self.bias = self.bias.to(torch.float32)
+        if not self.lazy_load:
+            super().load(weight_dict)
+            self.weight = self.weight.to(torch.float32)
+            if hasattr(self, "bias") and self.bias is not None:
+                self.bias = self.bias.to(torch.float32)
 
 
 class MMWeightQuantTemplate(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
         self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
         self.load_func = None
         self.weight_need_transpose = True
...
...
@@ -215,87 +259,133 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         self.lazy_load = lazy_load
         self.lazy_load_file = lazy_load_file
         self.infer_dtype = GET_DTYPE()
         self.bias_force_fp32 = False
 
     # =========================
     # weight load functions
     # =========================
     def load(self, weight_dict):
         self.load_quantized(weight_dict)
         if self.weight_need_transpose:
             if hasattr(self, "weight") and self.weight is not None:
                 self.weight = self.weight.t()
             if hasattr(self, "pin_weight") and self.pin_weight is not None:
                 self.pin_weight = self.pin_weight.t()
             if hasattr(self, "weight_cuda_buffer") and self.weight_cuda_buffer is not None:
                 self.weight_cuda_buffer = self.weight_cuda_buffer.t()
 
     def load_from_disk(self):
         # Need Rewrite
         if not torch._dynamo.is_compiling():
             self.weight = self.lazy_load_file.get_tensor(self.weight_name).pin_memory()
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float().pin_memory()
             if self.bias_name is not None:
                 self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype).pin_memory()
 
     def load_quantized(self, weight_dict):
         if self.create_cuda_buffer:
             self._load_cuda_buffers(weight_dict)
         elif self.create_cpu_buffer:
             self._load_cpu_pin_buffers()
         else:
             self.weight = self.lazy_load_file.get_tensor(self.weight_name)
             self.weight_scale = self.lazy_load_file.get_tensor(self.weight_scale_name).float()
             if self.bias_name is not None:
                 self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
             self._load_default_tensors(weight_dict)
             if self.weight_need_transpose:
                 self.weight = self.weight.t()
 
     def _load_cuda_buffers(self, weight_dict):
         source = self.lazy_load_file if self.lazy_load else weight_dict
         self.weight_cuda_buffer, self.weight_scale_cuda_buffer = self._get_cuda_tensor_pair(source, self.lazy_load)
         self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, self.lazy_load)
 
     def load(self, weight_dict):
     def _get_cuda_tensor_pair(self, source, is_lazy):
         if is_lazy:
             weight = source.get_tensor(self.weight_name).to(AI_DEVICE)
             scale = source.get_tensor(self.weight_scale_name).float().to(AI_DEVICE)
         else:
             weight = source[self.weight_name].to(AI_DEVICE)
             scale = source[self.weight_scale_name].float().to(AI_DEVICE)
         return weight, scale
 
     def _get_cuda_bias_tensor(self, source, is_lazy):
         if self.bias_name is None:
             return None
         if is_lazy:
             bias = source.get_tensor(self.bias_name)
             dtype = self.infer_dtype
         else:
             bias = source[self.bias_name]
             dtype = bias.dtype
         if self.bias_force_fp32:
             bias = bias.to(torch.float32)
         else:
             bias = bias.to(dtype)
         return bias.to(AI_DEVICE)
 
     def _load_cpu_pin_buffers(self):
         self.pin_weight, self.pin_weight_scale = self._get_cpu_pin_tensor_pair(self.lazy_load_file, is_lazy=True)
         self.pin_bias = self._get_cpu_pin_bias_tensor(self.lazy_load_file, is_lazy=True)
         self.bias = None
 
     def _get_cpu_pin_tensor_pair(self, source, is_lazy):
         if is_lazy:
             weight_tensor = source.get_tensor(self.weight_name)
             scale_tensor = source.get_tensor(self.weight_scale_name)
             scale_dtype = torch.float
         else:
             weight_tensor = source[self.weight_name]
             scale_tensor = source[self.weight_scale_name]
             scale_dtype = torch.float
         pin_weight = self._create_pin_tensor(weight_tensor)
         pin_scale = self._create_pin_tensor(scale_tensor, scale_dtype)
         return pin_weight, pin_scale
 
     def _get_cpu_pin_bias_tensor(self, source, is_lazy):
         if self.bias_name is None:
             return None
         if is_lazy:
             bias_tensor = source.get_tensor(self.bias_name)
             if not self.bias_force_fp32:
                 bias_tensor = bias_tensor.to(self.infer_dtype)
         else:
             bias_tensor = source[self.bias_name]
             if self.bias_force_fp32:
                 bias_tensor = bias_tensor.to(torch.float32)
         return self._create_pin_tensor(bias_tensor)
 
     def _create_pin_tensor(self, tensor, dtype=None):
         dtype = dtype or tensor.dtype
         pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=dtype)
         pin_tensor.copy_(tensor)
         del tensor
         return pin_tensor
 
     def _load_default_tensors(self, weight_dict):
         if not self.lazy_load:
             self.load_func(weight_dict)
             if self.weight_need_transpose:
                 if hasattr(self, "weight"):
                     self.weight = self.weight.t()
                 if hasattr(self, "pin_weight"):
                     self.pin_weight = self.pin_weight.t()
                 if hasattr(self, "weight_cuda_buffer"):
                     self.weight_cuda_buffer = self.weight_cuda_buffer.t()
 
     def clear(self):
         attrs = ["weight", "weight_scale", "bias", "pin_weight", "pin_weight_scale", "pin_bias"]
         for attr in attrs:
             if hasattr(self, attr):
                 delattr(self, attr)
                 setattr(self, attr, None)
 
     def _calculate_size(self):
         if self.bias is not None:
             return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size() + self.bias.numel() * self.bias.element_size()
         return self.weight.numel() * self.weight.element_size() + self.weight_scale.numel() * self.weight_scale.element_size()
             self.weight, self.weight_scale, self.pin_weight, self.pin_weight_scale = self._get_device_tensor_pair(weight_dict)
             self._load_default_bias(weight_dict)
         else:
             self.bias = None
             self.pin_bias = None
 
     def load_quantized(self, weight_dict):
         if self.create_cuda_buffer:
             # move to cuda buffer
             self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
             self.weight_scale_cuda_buffer = weight_dict[self.weight_scale_name].float().cuda()
     def _get_device_tensor_pair(self, source):
         device = source[self.weight_name].device
         if device.type == "cpu":
             pin_weight, pin_scale = self._get_cpu_pin_tensor_pair(source, is_lazy=False)
             return None, None, pin_weight, pin_scale
         else:
             device = weight_dict[self.weight_name].device
             if device.type == "cpu":
                 weight_shape = weight_dict[self.weight_name].shape
                 weight_dtype = weight_dict[self.weight_name].dtype
                 self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                 self.pin_weight.copy_(weight_dict[self.weight_name])
             return source[self.weight_name], source[self.weight_scale_name].float(), None, None
                 weight_scale_shape = weight_dict[self.weight_scale_name].shape
                 weight_scale_dtype = torch.float
                 self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                 self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                 del weight_dict[self.weight_name]
             else:
                 self.weight = weight_dict[self.weight_name]
                 self.weight_scale = weight_dict[self.weight_scale_name].float()
 
     def _load_default_bias(self, source):
         if self.bias_name is None:
             self.bias = None
             self.pin_bias = None
             self.bias_cuda_buffer = None
             return
         if self.bias_name is not None:
             if self.create_cuda_buffer:
                 # move to cuda buffer
                 self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
             else:
                 device = weight_dict[self.bias_name].device
                 if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
                     self.bias = weight_dict[self.bias_name]
         else:
         if self.create_cuda_buffer:
             self.bias_cuda_buffer = self._get_cuda_bias_tensor(source, is_lazy=False)
             self.bias = None
             self.pin_bias = None
         else:
             bias_tensor = source[self.bias_name].float() if self.bias_force_fp32 else source[self.bias_name]
             device = bias_tensor.device
             if device.type == "cpu":
                 self.pin_bias = self._get_cpu_pin_bias_tensor(source, is_lazy=False)
                 self.bias = None
             else:
                 self.bias = bias_tensor
                 self.pin_bias = None
 
     def load_fp8_perchannel_sym(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
...
...
@@ -320,7 +410,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp4(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
...
...
@@ -343,7 +433,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp6(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
...
...
@@ -366,7 +456,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
     def load_mxfp8(self, weight_dict):
         if self.config.get("weight_auto_quant", False):
             device = weight_dict[self.weight_name].device
-            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
+            self.weight = weight_dict[self.weight_name].to(AI_DEVICE).to(torch.bfloat16)
             self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
             self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
         else:
...
...
@@ -424,19 +514,16 @@ class MMWeightQuantTemplate(MMWeightTemplate):
         if self.bias_name is not None:
             if self.create_cuda_buffer:
                 # move to cuda buffer
-                self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
+                self.bias_cuda_buffer = weight_dict[self.bias_name].to(AI_DEVICE)
             else:
                 device = weight_dict[self.bias_name].device
-                if device.type == "cuda":
-                    self.bias = weight_dict[self.bias_name]
-                elif device.type == "cpu":
+                if device.type == "cpu":
                     bias_shape = weight_dict[self.bias_name].shape
                     bias_dtype = weight_dict[self.bias_name].dtype
                     self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                     self.pin_bias.copy_(weight_dict[self.bias_name])
                 else:
-                    raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
+                    self.bias = weight_dict[self.bias_name]
         else:
             self.bias = None
             self.pin_bias = None
...
...
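The block-index rewriting used by the disk-reload methods in this file relies on `re.sub` replacing only the first numeric segment of a weight name. A tiny standalone check of that pattern; the weight names here are made up for illustration:

    import re

    def retarget(name, block_index):
        # Replace the first ".<digits>" segment with the requested block index,
        # mirroring the re.sub(r"\.\d+", ..., count=1) calls in load_state_dict_from_disk.
        return re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1)

    print(retarget("blocks.0.attn.qkv.weight", 17))  # blocks.17.attn.qkv.weight
    print(retarget("blocks.4.ffn.0.weight", 2))      # blocks.2.ffn.0.weight (only the first index changes)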
@@ -548,6 +635,36 @@ class MMWeightQuantTemplate(MMWeightTemplate):
else
:
self
.
bias
=
None
def
load_state_dict_from_disk
(
self
,
block_index
,
adapter_block_index
=
None
):
if
self
.
is_post_adapter
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
else
:
self
.
weight_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_name
,
count
=
1
)
self
.
weight_scale_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
weight_scale_name
,
count
=
1
)
if
self
.
weight_need_transpose
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
t
()
else
:
weight_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
)
self
.
pin_weight
=
self
.
pin_weight
.
copy_
(
weight_tensor
)
weight_scale_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
)
self
.
pin_weight_scale
=
self
.
pin_weight_scale
.
copy_
(
weight_scale_tensor
)
del
weight_tensor
if
self
.
bias_name
is
not
None
:
if
self
.
is_post_adapter
:
assert
adapter_block_index
is
not
None
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
adapter_block_index
}
"
,
self
.
bias_name
,
count
=
1
)
else
:
self
.
bias_name
=
re
.
sub
(
r
"\.\d+"
,
lambda
m
:
f
".
{
block_index
}
"
,
self
.
bias_name
,
count
=
1
)
bias_tensor
=
self
.
lazy_load_file
.
get_tensor
(
self
.
bias_name
)
self
.
pin_bias
.
copy_
(
bias_tensor
)
del
bias_tensor
@
MM_WEIGHT_REGISTER
(
"fp8-vllm"
)
class
MMWeightWfp8channelAfp8channeldynamicVllm
(
MMWeightQuantTemplate
):
...
...
@@ -560,8 +677,8 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
Kernel: vllm
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_fp8_perchannel_sym
self
.
weight_need_transpose
=
True
self
.
act_quant_func
=
self
.
act_quant_fp8_perchannel_sym_vllm
...
...
@@ -595,8 +712,8 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
Kernel: vllm
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_int8_perchannel_sym
self
.
weight_need_transpose
=
True
self
.
act_quant_func
=
self
.
act_quant_int8_perchannel_sym_vllm
...
...
@@ -605,7 +722,7 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
dtype
=
input_tensor
.
dtype
device
=
input_tensor
.
device
output_tensor
=
torch
.
zeros
(
shape
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
False
)
output_tensor
=
torch
.
empty
(
shape
,
dtype
=
dtype
,
device
=
device
,
requires_grad
=
False
)
input_tensor_quant
,
input_tensor_scale
=
self
.
act_quant_func
(
input_tensor
)
torch
.
ops
.
_C
.
cutlass_scaled_mm
(
...
...
@@ -629,8 +746,8 @@ class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
Act: mxfp4
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_mxfp4
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_mxfp4
...
...
@@ -656,8 +773,8 @@ class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
Act: mxfp8
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_mxfp6
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_mxfp8
...
...
@@ -683,8 +800,8 @@ class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
Act: mxfp8
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_mxfp8
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_mxfp8
...
...
@@ -710,8 +827,8 @@ class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
Act: nvfp4
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_nvfp4
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_nvfp4
...
...
@@ -722,13 +839,13 @@ class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
return
output_tensor
def
to_cuda
(
self
,
non_blocking
=
False
):
self
.
weight
=
self
.
pin_weight
.
cuda
(
non_blocking
=
non_blocking
)
self
.
weight
=
self
.
pin_weight
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_weight_scale"
):
self
.
weight_scale
=
self
.
pin_weight_scale
.
cuda
(
non_blocking
=
non_blocking
)
self
.
input_global_scale
=
self
.
pin_input_global_scale
.
cuda
(
non_blocking
=
non_blocking
)
self
.
alpha
=
self
.
pin_alpha
.
cuda
(
non_blocking
=
non_blocking
)
self
.
weight_scale
=
self
.
pin_weight_scale
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
self
.
input_global_scale
=
self
.
pin_input_global_scale
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
self
.
alpha
=
self
.
pin_alpha
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
if
hasattr
(
self
,
"pin_bias"
)
and
self
.
pin_bias
is
not
None
:
self
.
bias
=
self
.
pin_bias
.
cuda
(
non_blocking
=
non_blocking
)
self
.
bias
=
self
.
pin_bias
.
to
(
AI_DEVICE
,
non_blocking
=
non_blocking
)
def
to_cpu
(
self
,
non_blocking
=
False
):
if
hasattr
(
self
,
"pin_weight"
):
...
...
@@ -758,8 +875,8 @@ class MMCalibNvfp4(MMWeight):
absmax: torch.max(torch.abs(input_tensor))
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
running_absmax
=
None
self
.
count
=
0
self
.
decay
=
0.9
...
...
@@ -794,11 +911,12 @@ class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
Kernel: Q8F
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_fp8_perchannel_sym
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_fp8_perchannel_sym_vllm
self
.
bias_force_fp32
=
True
def
apply
(
self
,
input_tensor
):
input_tensor_quant
,
input_tensor_scale
=
self
.
act_quant_func
(
input_tensor
)
...
...
@@ -824,8 +942,8 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
Kernel: Q8F
"""
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
def
__init__
(
self
,
weight_name
,
bias_name
,
create_cuda_buffer
=
False
,
create_cpu_buffer
=
False
,
lazy_load
=
False
,
lazy_load_file
=
None
,
is_post_adapter
=
False
):
super
().
__init__
(
weight_name
,
bias_name
,
create_cuda_buffer
,
create_cpu_buffer
,
lazy_load
,
lazy_load_file
,
is_post_adapter
)
self
.
load_func
=
self
.
load_int8_perchannel_sym
self
.
weight_need_transpose
=
False
self
.
act_quant_func
=
self
.
act_quant_int8_perchannel_sym_vllm
...
...
@@ -855,8 +973,8 @@ class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuant
    Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
    """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perblock128_sym
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
...
@@ -889,8 +1007,8 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
    Kernel: Sgl-kernel
    """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_fp8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
...
@@ -903,7 +1021,7 @@ class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
            input_tensor_scale,
            self.weight_scale,
            self.infer_dtype,
-            bias=self.bias,
+            bias=self.bias if self.bias is not None else None,
        )
        return output_tensor
...
@@ -919,8 +1037,8 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
    """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
...
@@ -944,7 +1062,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):

 @MM_WEIGHT_REGISTER("int8-torchao")
-class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
+class MMWeightWint8channelAint8channeldynamicTorchao(MMWeightQuantTemplate):
    """
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Torchao
...
@@ -954,8 +1072,8 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
    Kernel: Torchao
    """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_int8_perchannel_sym
        self.weight_need_transpose = True
        self.act_quant_func = self.act_quant_int8_perchannel_sym_torchao
...
@@ -971,33 +1089,34 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):

 class MMWeightGGUFTemplate(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)

    def load(self, weight_dict):
-        assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
-        self.weight = weight_dict[self.weight_name]
-        weight_shape = self.weight.shape
-        weight_dtype = self.weight.dtype
-        if isinstance(self.weight, GGMLTensor):
-            self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
-            self.pin_weight.copy_from(self.weight)
-        else:
-            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-            self.pin_weight.copy_(weight_dict[self.weight_name])
-        if self.bias_name is not None:
-            self.bias = weight_dict[self.bias_name]
-            if isinstance(self.bias, GGMLTensor):
-                self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
-                self.pin_bias.copy_from(self.bias)
-            else:
-                self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
-                self.pin_bias.copy_(weight_dict[self.bias_name])
-        else:
-            self.bias = None
+        if not self.lazy_load:
+            assert not self.create_cuda_buffer, "GGUF Unsupported offload block"
+            self.weight = weight_dict[self.weight_name]
+            weight_shape = self.weight.shape
+            weight_dtype = self.weight.dtype
+            if isinstance(self.weight, GGMLTensor):
+                self.pin_weight = GGMLTensor.empty_pinned(weight_shape, orig_shape=self.weight.orig_shape, dtype=weight_dtype, gguf_type=self.weight.gguf_type)
+                self.pin_weight.copy_from(self.weight)
+            else:
+                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
+                self.pin_weight.copy_(weight_dict[self.weight_name])
+            if self.bias_name is not None:
+                self.bias = weight_dict[self.bias_name]
+                if isinstance(self.bias, GGMLTensor):
+                    self.pin_bias = GGMLTensor.empty_pinned(self.bias.shape, orig_shape=self.bias.orig_shape, dtype=self.bias.dtype, gguf_type=self.bias.gguf_type)
+                    self.pin_bias.copy_from(self.bias)
+                else:
+                    self.pin_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+                    self.pin_bias.copy_(weight_dict[self.bias_name])
+            else:
+                self.bias = None

    def load_state_dict(self, destination, block_index, adapter_block_index=None):
        if self.is_post_adapter:
...
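The GGUF path above keeps weights in a pinned staging copy (via `GGMLTensor.empty_pinned` for GGML-packed tensors, plain `torch.empty(..., pin_memory=True)` otherwise). A plain-tensor sketch of that staging step, with no GGUF specifics and a hypothetical helper name, looks like this:

import torch

def make_pinned_copy(t: torch.Tensor) -> torch.Tensor:
    # Allocate page-locked host memory with the same shape/dtype and copy in.
    # Pinned memory lets later .to(device, non_blocking=True) overlap with compute.
    pinned = torch.empty(t.shape, dtype=t.dtype, pin_memory=torch.cuda.is_available())
    pinned.copy_(t)
    return pinned

weights = {"blocks.0.proj.weight": torch.randn(8, 8)}
pin_weight = make_pinned_copy(weights["blocks.0.proj.weight"])
del weights["blocks.0.proj.weight"]  # the pinned copy is now the CPU source of truth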
@@ -1035,9 +1154,7 @@ class MMWeightGGUFTemplate(MMWeightTemplate):
        if tensor is None:
            return
-        device = tensor.device
        weight = gguf_dequantize_tensor(tensor, dtype)
        # prevent propagating custom tensor class
        if isinstance(weight, GGMLTensor):
            weight = torch.Tensor(weight)
...
@@ -1135,8 +1252,8 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    Kernel: Marlin
    """

-    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter)
+    def __init__(self, weight_name, bias_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
...
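All of the classes touched above are registered by name (for example `MM_WEIGHT_REGISTER["int8-torchao"]`, `LN_WEIGHT_REGISTER["Triton"]`) and later looked up from config strings. A toy registry of that shape is sketched below; it assumes nothing about the real `registry_factory` implementation, and the class and weight names are made up.

class Register:
    def __init__(self):
        self._classes = {}

    def __call__(self, name):
        # Used as a decorator: @MM_WEIGHT_REGISTER("Default")
        def wrap(cls):
            self._classes[name] = cls
            return cls
        return wrap

    def __getitem__(self, name):
        return self._classes[name]

MM_WEIGHT_REGISTER = Register()

@MM_WEIGHT_REGISTER("Default")
class MMWeight:
    def __init__(self, weight_name, bias_name):
        self.weight_name, self.bias_name = weight_name, bias_name

# Config-driven lookup, as the weight modules do with self.mm_type:
cls = MM_WEIGHT_REGISTER["Default"]
mm = cls("blocks.0.attn.q.weight", "blocks.0.attn.q.bias")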
lightx2v/common/ops/norm/layer_norm_weight.py
View file @
74eeb429
...
...
@@ -5,16 +5,18 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import LN_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE
 from .triton_ops import norm_infer


 class LNWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.eps = eps
        self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
...
@@ -23,53 +25,71 @@ class LNWeightTemplate(metaclass=ABCMeta):
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

    def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                if self.weight_name is not None:
-                    self.weight_cuda_buffer = weight_dict[self.weight_name].cuda().t()
-                    if self.bias_name is not None:
-                        self.bias_cuda_buffer = weight_dict[self.bias_name].cuda()
-            else:
-                if self.weight_name is not None:
-                    device = weight_dict[self.weight_name].device
-                    if device.type == "cpu":
-                        weight_shape = weight_dict[self.weight_name].shape
-                        weight_dtype = weight_dict[self.weight_name].dtype
-                        self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                        self.pin_weight.copy_(weight_dict[self.weight_name])
-                        if self.bias_name is not None:
-                            bias_shape = weight_dict[self.bias_name].shape
-                            bias_dtype = weight_dict[self.bias_name].dtype
-                            self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
-                            self.pin_bias.copy_(weight_dict[self.bias_name])
-                        else:
-                            self.bias = None
-                            self.pin_bias = None
-                        del weight_dict[self.weight_name]
-                    else:
-                        self.weight = weight_dict[self.weight_name]
-                        if self.bias_name is not None:
-                            self.bias = weight_dict[self.bias_name]
-                        else:
-                            self.bias = None
-                else:
-                    self.weight = None
-                    self.bias = None
-
-    def _calculate_size(self):
-        if self.weight is None:
-            return 0
-        if self.bias is not None:
-            return self.weight.numel() * self.weight.element_size() + self.bias.numel() * self.bias.element_size()
-        return self.weight.numel() * self.weight.element_size()
-
-    def clear(self):
-        attrs = ["weight", "bias", "pinned_weight", "pinned_bias"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
+        if self.create_cuda_buffer:
+            self._load_cuda_buffers(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffers()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load and self.weight_name is not None:
+            device = weight_dict[self.weight_name].device
+            if device.type == "cpu":
+                weight_tensor = weight_dict[self.weight_name]
+                self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+                bias_tensor = weight_dict[self.bias_name] if self.bias_name is not None else None
+                self.pin_bias = self._create_cpu_pin_tensor(bias_tensor) if bias_tensor is not None else None
+                self.bias = None
+                del weight_dict[self.weight_name]
+            else:
+                self.weight = weight_dict[self.weight_name]
+                self.bias = weight_dict[self.bias_name] if self.bias_name is not None else None
+        else:
+            self.weight = None
+            self.bias = None
+
+    def _get_tensor(self, name, weight_dict=None, use_infer_dtype=False):
+        if name is None:
+            return None
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[name]
+        return tensor
+
+    def _create_cpu_pin_tensor(self, tensor):
+        if tensor is None:
+            return None
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffers(self, weight_dict):
+        weight_tensor = self._get_tensor(self.weight_name, weight_dict, use_infer_dtype=self.lazy_load)
+        if weight_tensor is not None:
+            self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+        bias_tensor = self._get_tensor(self.bias_name, weight_dict, use_infer_dtype=self.lazy_load)
+        if bias_tensor is not None:
+            self.bias_cuda_buffer = bias_tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffers(self):
+        weight_tensor = self._get_tensor(self.weight_name, use_infer_dtype=True)
+        if weight_tensor is not None:
+            self.pin_weight = self._create_cpu_pin_tensor(weight_tensor)
+        else:
+            self.weight = None
+        bias_tensor = self._get_tensor(self.bias_name, use_infer_dtype=True)
+        if bias_tensor is not None:
+            self.pin_bias = self._create_cpu_pin_tensor(bias_tensor)
+        else:
+            self.bias = None
+            self.pin_bias = None

    @abstractmethod
    def apply(self, input_tensor):
...
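The refactor above replaces one monolithic `load` with a small dispatcher plus three loaders: persistent device buffer, pinned CPU buffer, or default in-place tensors. Stripped of the LightX2V details, the control flow is roughly the following sketch (class and attribute names are illustrative, not the repo's):

import torch

class WeightLoader:
    def __init__(self, name, create_cuda_buffer=False, create_cpu_buffer=False, device="cpu"):
        self.name = name
        self.create_cuda_buffer = create_cuda_buffer
        self.create_cpu_buffer = create_cpu_buffer
        self.device = device  # stand-in for AI_DEVICE

    def load(self, weight_dict):
        # One entry point, three mutually exclusive strategies.
        if self.create_cuda_buffer:
            self.buffer = weight_dict[self.name].to(self.device)       # persistent device buffer
        elif self.create_cpu_buffer:
            src = weight_dict[self.name]
            self.pin = torch.empty(src.shape, dtype=src.dtype, pin_memory=torch.cuda.is_available())
            self.pin.copy_(src)                                         # pinned staging copy
        else:
            self.tensor = weight_dict[self.name]                        # keep the tensor as-is

loader = WeightLoader("norm.weight", create_cpu_buffer=True)
loader.load({"norm.weight": torch.ones(64)})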
@@ -81,11 +101,11 @@ class LNWeightTemplate(metaclass=ABCMeta):
    def to_cuda(self, non_blocking=False):
        if hasattr(self, "pin_weight") and self.pin_weight is not None:
-            self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+            self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)
        else:
            self.weight = None
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
-            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)
+            self.bias = self.pin_bias.to(AI_DEVICE, non_blocking=non_blocking)
        else:
            self.bias = None
...
@@ -129,28 +149,33 @@ class LNWeightTemplate(metaclass=ABCMeta):
        else:
            self.bias = None

-
-@LN_WEIGHT_REGISTER("Default")
-class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if self.weight_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-        else:
-            self.weight = None
-        if self.bias_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
-        else:
-            self.bias = None
+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.weight_name is not None:
+            if self.is_post_adapter:
+                self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+            else:
+                self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+            weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+            self.pin_weight = self.pin_weight.copy_(weight_tensor)
+            del weight_tensor
+        if self.bias_name is not None:
+            if self.is_post_adapter:
+                assert adapter_block_index is not None
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.bias_name, count=1)
+            else:
+                self.bias_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.bias_name, count=1)
+            bias_tensor = self.lazy_load_file.get_tensor(self.bias_name).to(self.infer_dtype)
+            self.pin_bias.copy_(bias_tensor)
+            del bias_tensor
+
+
+@LN_WEIGHT_REGISTER("Default")
+class LNWeight(LNWeightTemplate):
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        if self.sensitive_layer_dtype != self.infer_dtype:
...
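`load_state_dict_from_disk` above rewrites the first numeric segment of a weight name so that one pinned buffer can be refilled for any block index. The regex trick in isolation (weight name is just an example):

import re

def retarget(name: str, block_index: int) -> str:
    # Replace only the first ".<digits>" occurrence, e.g. ".0" -> ".17".
    return re.sub(r"\.\d+", lambda m: f".{block_index}", name, count=1)

print(retarget("double_blocks.0.img_attn_q.weight", 17))
# -> double_blocks.17.img_attn_q.weight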
@@ -169,25 +194,8 @@ class LNWeight(LNWeightTemplate):

 @LN_WEIGHT_REGISTER("Triton")
 class LNWeight(LNWeightTemplate):
-    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, bias_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if self.weight_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-        else:
-            self.weight = None
-        if self.bias_name is not None:
-            if not torch._dynamo.is_compiling():
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE()).pin_memory()
-            else:
-                self.bias = self.lazy_load_file.get_tensor(self.bias_name).to(GET_DTYPE())
-        else:
-            self.bias = None
+    def __init__(self, weight_name=None, bias_name=None, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, bias_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        input_tensor = norm_infer(input_tensor, self.weight, self.bias, self.eps)
...
lightx2v/common/ops/norm/rms_norm_weight.py
View file @
74eeb429
...
...
@@ -5,6 +5,7 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import RMS_WEIGHT_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE

 try:
     import sgl_kernel
...
@@ -13,10 +14,11 @@ except ImportError:

 class RMSWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
        self.weight_name = weight_name
        self.eps = eps
        self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
...
@@ -25,26 +27,45 @@ class RMSWeightTemplate(metaclass=ABCMeta):
        self.config = {}

    def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                self.weight_cuda_buffer = weight_dict[self.weight_name].cuda()
-            else:
-                device = weight_dict[self.weight_name].device
-                if device.type == "cpu":
-                    weight_shape = weight_dict[self.weight_name].shape
-                    weight_dtype = weight_dict[self.weight_name].dtype
-                    self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
-                    self.pin_weight.copy_(weight_dict[self.weight_name])
-                    del weight_dict[self.weight_name]
-                else:
-                    self.weight = weight_dict[self.weight_name]
-
-    def clear(self):
-        attrs = ["weight", "pinned_weight"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
+        if self.create_cuda_buffer:
+            self._load_cuda_buffer(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffer()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load:
+            device = weight_dict[self.weight_name].device
+            if device.type == "cpu":
+                weight_tensor = weight_dict[self.weight_name]
+                self.pin_weight = self._create_cpu_pin_weight(weight_tensor)
+                del weight_dict[self.weight_name]
+            else:
+                self.weight = weight_dict[self.weight_name]
+
+    def _get_weight_tensor(self, weight_dict=None, use_infer_dtype=False):
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(self.weight_name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[self.weight_name]
+        return tensor
+
+    def _create_cpu_pin_weight(self, tensor):
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffer(self, weight_dict):
+        weight_tensor = self._get_weight_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+        self.weight_cuda_buffer = weight_tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffer(self):
+        weight_tensor = self._get_weight_tensor(use_infer_dtype=True)
+        self.pin_weight = self._create_cpu_pin_weight(weight_tensor)

    @abstractmethod
    def apply(self, input_tensor):
...
@@ -55,7 +76,7 @@ class RMSWeightTemplate(metaclass=ABCMeta):
        self.config = config

    def to_cuda(self, non_blocking=False):
-        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
+        self.weight = self.pin_weight.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
...
@@ -63,31 +84,6 @@ class RMSWeightTemplate(metaclass=ABCMeta):
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)

-    def _calculate_size(self):
-        return self.weight.numel() * self.weight.element_size()
-
-
-@RMS_WEIGHT_REGISTER("Default")
-class RMSWeight(RMSWeightTemplate):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
-
-    def _norm(self, x):
-        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
-
-    def apply(self, input_tensor):
-        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
-            input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
-        else:
-            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
-        return input_tensor
-
    def state_dict(self, destination=None):
        if destination is None:
            destination = {}
...
@@ -106,6 +102,32 @@ class RMSWeight(RMSWeightTemplate):
            return
        self.weight = self.weight_cuda_buffer.copy_(destination[weight_name], non_blocking=True)

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.weight_name, count=1)
+        else:
+            self.weight_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.weight_name, count=1)
+        weight_tensor = self.lazy_load_file.get_tensor(self.weight_name).to(self.infer_dtype)
+        self.pin_weight = self.pin_weight.copy_(weight_tensor)
+        del weight_tensor
+
+
+@RMS_WEIGHT_REGISTER("Default")
+class RMSWeight(RMSWeightTemplate):
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
+
+    def apply(self, input_tensor):
+        if GET_SENSITIVE_DTYPE() != GET_DTYPE():
+            input_tensor = self._norm(input_tensor).type_as(input_tensor) * self.weight
+        else:
+            input_tensor = self._norm(input_tensor.float()).type_as(input_tensor) * self.weight
+        return input_tensor
+

 @RMS_WEIGHT_REGISTER("sgl-kernel")
 class RMSWeightSgl(RMSWeight):
...
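The `_norm` used by `RMSWeight` above is the usual RMSNorm: the input is scaled by the reciprocal root-mean-square of its last dimension and then multiplied by the learned weight. As a quick standalone check of the formula:

import torch

def rms_norm(x, weight, eps=1e-6):
    # x / sqrt(mean(x^2) + eps) * weight, matching _norm(...) * self.weight above
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight

x = torch.tensor([[3.0, 4.0]])          # RMS = sqrt((9 + 16) / 2) ≈ 3.5355
w = torch.ones(2)
print(rms_norm(x, w))                    # ≈ [[0.8485, 1.1314]]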
@@ -113,18 +135,13 @@ class RMSWeightSgl(RMSWeight):
        self,
        weight_name,
        create_cuda_buffer=False,
+        create_cpu_buffer=False,
        lazy_load=False,
        lazy_load_file=None,
        is_post_adapter=False,
        eps=1e-6,
    ):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
-
-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE()).pin_memory()
-        else:
-            self.weight = self.lazy_load_file.get_tensor(self.weight_name).to(GET_DTYPE())
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        if sgl_kernel is not None and self.sensitive_layer_dtype == self.infer_dtype:
...
@@ -146,8 +163,8 @@ class RMSWeightSgl(RMSWeight):

 @RMS_WEIGHT_REGISTER("fp32_variance")
 class RMSWeightFP32(RMSWeight):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def apply(self, input_tensor):
        input_dtype = input_tensor.dtype
...
@@ -165,8 +182,8 @@ class RMSWeightFP32(RMSWeight):

 @RMS_WEIGHT_REGISTER("self_forcing")
 class RMSWeightSF(RMSWeight):
-    def __init__(self, weight_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
-        super().__init__(weight_name, create_cuda_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)
+    def __init__(self, weight_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False, eps=1e-6):
+        super().__init__(weight_name, create_cuda_buffer, create_cpu_buffer, lazy_load, lazy_load_file, is_post_adapter, eps)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
...
lightx2v/common/ops/tensor/tensor.py
View file @
74eeb429
...
...
@@ -4,52 +4,64 @@ import torch
 from lightx2v.utils.envs import *
 from lightx2v.utils.registry_factory import TENSOR_REGISTER
+from lightx2v_platform.base.global_var import AI_DEVICE


 @TENSOR_REGISTER("Default")
 class DefaultTensor:
-    def __init__(self, tensor_name, create_cuda_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
+    def __init__(self, tensor_name, create_cuda_buffer=False, create_cpu_buffer=False, lazy_load=False, lazy_load_file=None, is_post_adapter=False):
        self.tensor_name = tensor_name
        self.lazy_load = lazy_load
        self.lazy_load_file = lazy_load_file
        self.is_post_adapter = is_post_adapter
        self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
        self.infer_dtype = GET_DTYPE()
        self.sensitive_layer_dtype = GET_SENSITIVE_DTYPE()

-    def load_from_disk(self):
-        if not torch._dynamo.is_compiling():
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype).pin_memory()
-        else:
-            self.tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
-
-    def load(self, weight_dict):
-        if not self.lazy_load:
-            if self.create_cuda_buffer:
-                self.tensor_cuda_buffer = weight_dict[self.tensor_name].cuda()
-            else:
-                device = weight_dict[self.tensor_name].device
-                if device.type == "cpu":
-                    tensor_shape = weight_dict[self.tensor_name].shape
-                    tensor_dtype = weight_dict[self.tensor_name].dtype
-                    self.pin_tensor = torch.empty(tensor_shape, pin_memory=True, dtype=tensor_dtype)
-                    self.pin_tensor.copy_(weight_dict[self.tensor_name])
-                    del weight_dict[self.tensor_name]
-                else:
-                    self.tensor = weight_dict[self.tensor_name]
-
-    def clear(self):
-        attrs = ["tensor", "pinned_tensor"]
-        for attr in attrs:
-            if hasattr(self, attr):
-                delattr(self, attr)
-                setattr(self, attr, None)
-
-    def _calculate_size(self):
-        return self.tensor.numel() * self.tensor.element_size()
+    def load(self, weight_dict):
+        if self.create_cuda_buffer:
+            self._load_cuda_buffer(weight_dict)
+        elif self.create_cpu_buffer:
+            self._load_cpu_pin_buffer()
+        else:
+            self._load_default_tensors(weight_dict)
+
+    def _load_default_tensors(self, weight_dict):
+        if not self.lazy_load:
+            device = weight_dict[self.tensor_name].device
+            if device.type == "cpu":
+                tensor = weight_dict[self.tensor_name]
+                self.pin_tensor = self._create_cpu_pin_tensor(tensor)
+                del weight_dict[self.tensor_name]
+            else:
+                self.tensor = weight_dict[self.tensor_name]
+
+    def _get_tensor(self, weight_dict=None, use_infer_dtype=False):
+        if self.lazy_load:
+            tensor = self.lazy_load_file.get_tensor(self.tensor_name)
+            if use_infer_dtype:
+                tensor = tensor.to(self.infer_dtype)
+        else:
+            tensor = weight_dict[self.tensor_name]
+        return tensor
+
+    def _create_cpu_pin_tensor(self, tensor):
+        pin_tensor = torch.empty(tensor.shape, pin_memory=True, dtype=tensor.dtype)
+        pin_tensor.copy_(tensor)
+        del tensor
+        return pin_tensor
+
+    def _load_cuda_buffer(self, weight_dict):
+        tensor = self._get_tensor(weight_dict, use_infer_dtype=self.lazy_load)
+        self.tensor_cuda_buffer = tensor.to(AI_DEVICE)
+
+    def _load_cpu_pin_buffer(self):
+        tensor = self._get_tensor(use_infer_dtype=True)
+        self.pin_tensor = self._create_cpu_pin_tensor(tensor)

    def to_cuda(self, non_blocking=False):
-        self.tensor = self.pin_tensor.cuda(non_blocking=non_blocking)
+        self.tensor = self.pin_tensor.to(AI_DEVICE, non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_tensor"):
...
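`load_state_dict` in this class (shown in the next hunk) refills a preallocated device buffer in place with `copy_(..., non_blocking=True)` rather than allocating a new tensor per block. Below is a minimal sketch of that buffer-reuse pattern, guarded so it also runs on CPU-only machines; the shapes and loop are illustrative only.

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"

# One device-resident buffer, sized once.
cuda_buffer = torch.empty(4, 4, device=device)

# Pinned host copies for two "blocks"; pinning enables truly async copies on CUDA.
pin = torch.cuda.is_available()
blocks = [torch.full((4, 4), float(i)).pin_memory() if pin else torch.full((4, 4), float(i)) for i in range(2)]

for i, src in enumerate(blocks):
    cuda_buffer.copy_(src, non_blocking=True)   # refill in place, no new allocation
    # ... run this block's compute against cuda_buffer here ...

if device == "cuda":
    torch.cuda.synchronize()                     # make sure the last copy has landed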
@@ -69,8 +81,18 @@ class DefaultTensor:
            tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
        else:
            tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
        if tensor_name not in destination:
            self.tensor = None
            return
        self.tensor = self.tensor_cuda_buffer.copy_(destination[tensor_name], non_blocking=True)

+    def load_state_dict_from_disk(self, block_index, adapter_block_index=None):
+        if self.is_post_adapter:
+            assert adapter_block_index is not None
+            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{adapter_block_index}", self.tensor_name, count=1)
+        else:
+            self.tensor_name = re.sub(r"\.\d+", lambda m: f".{block_index}", self.tensor_name, count=1)
+        tensor = self.lazy_load_file.get_tensor(self.tensor_name).to(self.infer_dtype)
+        self.pin_tensor = self.pin_tensor.copy_(tensor)
+        del tensor
lightx2v/models/input_encoders/hf/qwen25/qwen25_vlforconditionalgeneration.py
100644 → 100755
View file @
74eeb429
...
...
@@ -62,17 +62,13 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
        self.VAE_IMAGE_SIZE = 1024 * 1024
        self.cpu_offload = config.get("cpu_offload", False)
-        if self.cpu_offload:
-            self.device = torch.device("cpu")
-        else:
-            self.device = torch.device(AI_DEVICE)
        self.dtype = torch.bfloat16
        self.load()

    def load(self):
        self.text_encoder = Qwen2_5_VLForConditionalGeneration.from_pretrained(os.path.join(self.config["model_path"], "text_encoder"), torch_dtype=torch.bfloat16)
        if not self.cpu_offload:
-            self.text_encoder = self.text_encoder.to(self.device)
+            self.text_encoder = self.text_encoder.to(AI_DEVICE)
        self.tokenizer = Qwen2Tokenizer.from_pretrained(os.path.join(self.config["model_path"], "tokenizer"))
        if self.config["task"] == "i2i":
...
@@ -98,7 +94,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
    @torch.no_grad()
    def infer(self, text, image_list=None):
        if self.cpu_offload:
-            self.text_encoder.to(self.device)
+            self.text_encoder.to(AI_DEVICE)
        if image_list is not None:
            condition_image_list = []
...
@@ -133,7 +129,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
                images=condition_image_list,
                padding=True,
                return_tensors="pt",
-            ).to(torch.device(self.device))
+            ).to(AI_DEVICE)
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
...
@@ -156,7 +152,7 @@ class Qwen25_VLForConditionalGeneration_TextEncoder:
            txt = [template.format(e) for e in text]
            image_info = {}
-            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(self.device)
+            model_inputs = self.tokenizer(txt, max_length=self.tokenizer_max_length + drop_idx, padding=True, truncation=True, return_tensors="pt").to(AI_DEVICE)
            encoder_hidden_states = self.text_encoder(
                input_ids=model_inputs.input_ids,
                attention_mask=model_inputs.attention_mask,
...
prompt_embeds
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
max_seq_len
-
u
.
size
(
0
),
u
.
size
(
1
))])
for
u
in
split_hidden_states
])
encoder_attention_mask
=
torch
.
stack
([
torch
.
cat
([
u
,
u
.
new_zeros
(
max_seq_len
-
u
.
size
(
0
))])
for
u
in
attn_mask_list
])
prompt_embeds
=
prompt_embeds
.
to
(
dtype
=
self
.
dtype
,
device
=
self
.
device
)
prompt_embeds
=
prompt_embeds
.
to
(
dtype
=
self
.
dtype
,
device
=
AI_DEVICE
)
prompt_embeds_mask
=
encoder_attention_mask
_
,
seq_len
,
_
=
prompt_embeds
.
shape
...
...
lightx2v/models/input_encoders/hf/wan/t5/model.py
View file @
74eeb429
...
...
@@ -515,7 +515,7 @@ class T5Encoder(nn.Module):
e
=
pos_bias
else
:
lq
,
lk
=
x
.
size
(
1
),
x
.
size
(
1
)
rel_pos
=
torch
.
arange
(
lk
,
device
=
"cuda"
).
unsqueeze
(
0
)
-
torch
.
arange
(
lq
,
device
=
"cuda"
).
unsqueeze
(
1
)
rel_pos
=
torch
.
arange
(
lk
,
device
=
AI_DEVICE
).
unsqueeze
(
0
)
-
torch
.
arange
(
lq
,
device
=
AI_DEVICE
).
unsqueeze
(
1
)
num_buckets
=
block
.
pos_embedding
.
weight
.
shape
[
0
]
//
2
rel_buckets
=
(
rel_pos
>
0
).
long
()
*
num_buckets
rel_pos
=
torch
.
abs
(
rel_pos
)
...
...
@@ -532,28 +532,21 @@ class T5Encoder(nn.Module):
return
x
def
forward_with_offload
(
self
,
ids
,
mask
=
None
):
self
.
token_embedding
=
self
.
token_embedding
.
to
(
"cuda"
)
self
.
pos_embedding
=
self
.
pos_embedding
.
to
(
"cuda"
)
if
self
.
pos_embedding
is
not
None
else
None
self
.
token_embedding
=
self
.
token_embedding
.
to
(
AI_DEVICE
)
self
.
pos_embedding
=
self
.
pos_embedding
.
to
(
AI_DEVICE
)
if
self
.
pos_embedding
is
not
None
else
None
x
=
self
.
token_embedding
(
ids
)
x
=
self
.
dropout
(
x
)
e
=
self
.
pos_embedding
(
x
.
size
(
1
),
x
.
size
(
1
))
if
self
.
shared_pos
else
None
self
.
norm
=
self
.
norm
.
to
(
"cuda"
)
self
.
norm
=
self
.
norm
.
to
(
AI_DEVICE
)
for
block_idx
in
range
(
len
(
self
.
blocks
)):
self
.
block_idx
=
block_idx
if
block_idx
==
0
:
self
.
offload_manager
.
cuda_buffers
[
0
].
load_state_dict
(
self
.
blocks
[
block_idx
].
state_dict
(),
block_idx
,
)
if
block_idx
<
len
(
self
.
blocks
)
-
1
:
self
.
offload_manager
.
prefetch_weights
(
block_idx
+
1
,
self
.
blocks
)
with
torch
.
cuda
.
stream
(
self
.
offload_manager
.
compute_stream
):
x
=
self
.
forward_block_with_offload
(
self
.
offload_manager
.
cuda_buffers
[
0
],
x
,
mask
,
pos_bias
=
e
)
self
.
offload_manager
.
swap_blocks
()
self
.
offload_manager
.
cuda_buffers
[
0
].
load_state_dict
(
self
.
blocks
[
block_idx
].
state_dict
(),
block_idx
,
)
x
=
self
.
forward_block_with_offload
(
self
.
offload_manager
.
cuda_buffers
[
0
],
x
,
mask
,
pos_bias
=
e
)
x
=
self
.
norm
(
x
)
x
=
self
.
dropout
(
x
)
...
...
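The rewritten `forward_with_offload` above drops the prefetch/swap machinery for the T5 encoder and simply refills one device buffer per block before running it. The shape of that loop, with toy linear layers instead of T5 attention blocks, is sketched here (all names and sizes are illustrative):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

blocks = [nn.Linear(8, 8) for _ in range(4)]          # CPU-resident "offloaded" blocks
buffer = nn.Linear(8, 8).to(device)                    # single device-resident buffer

x = torch.randn(2, 8, device=device)
for blk in blocks:
    buffer.load_state_dict(blk.state_dict())           # copy this block's weights up
    x = buffer(x)                                       # run it through the shared buffer
print(x.shape)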
lightx2v/models/networks/hunyuan_video/infer/feature_caching/transformer_infer.py
View file @
74eeb429
...
...
@@ -6,6 +6,7 @@ import torch
 import torch.nn.functional as F

 from lightx2v.models.networks.hunyuan_video.infer.offload.transformer_infer import HunyuanVideo15OffloadTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE


 class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerInfer):
...
@@ -101,8 +102,8 @@ class HunyuanVideo15TransformerInferMagCaching(HunyuanVideo15OffloadTransformerI
    def infer_using_cache(self, infer_module_out):
        residual_img = self.residual_cache[self.scheduler.infer_condition]
        residual_txt = self.residual_cache_txt[self.scheduler.infer_condition]
-        infer_module_out.img.add_(residual_img.cuda())
-        infer_module_out.txt.add_(residual_txt.cuda())
+        infer_module_out.img.add_(residual_img.to(AI_DEVICE))
+        infer_module_out.txt.add_(residual_txt.to(AI_DEVICE))

    def clear(self):
        self.accumulated_err = {True: 0.0, False: 0.0}
...
lightx2v/models/networks/hunyuan_video/infer/offload/transformer_infer.py
View file @
74eeb429
...
...
@@ -2,6 +2,9 @@ import torch
 from lightx2v.common.offload.manager import WeightAsyncStreamManager
 from lightx2v.models.networks.hunyuan_video.infer.transformer_infer import HunyuanVideo15TransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
...
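`torch_device_module = getattr(torch, AI_DEVICE)` lets the same offload loop call `torch.cuda.stream(...)` or another backend's equivalent without branching on the device type. A small sketch of the idea follows; only the CUDA case is exercised, and it assumes other backends expose a compatible `torch.<backend>` module, which is the premise of the change above rather than something this commit documents.

import torch

AI_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

if AI_DEVICE == "cuda":
    torch_device_module = getattr(torch, AI_DEVICE)      # -> torch.cuda
    compute_stream = torch_device_module.Stream()
    copy_stream = torch_device_module.Stream()

    weight_cpu = torch.randn(1024, 1024).pin_memory()
    weight_gpu = torch.empty_like(weight_cpu, device="cuda")
    x = torch.randn(1024, 1024, device="cuda")

    # Copy the next block's weights on one stream while computing on another.
    with torch_device_module.stream(copy_stream):
        weight_gpu.copy_(weight_cpu, non_blocking=True)
    with torch_device_module.stream(compute_stream):
        y = x @ x
    torch_device_module.synchronize()
else:
    print("No accelerator available; the stream overlap only applies on a device backend.")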
@@ -26,6 +29,6 @@ class HunyuanVideo15OffloadTransformerInfer(HunyuanVideo15TransformerInfer):
                self.offload_manager.init_first_buffer(weights.double_blocks)
            if block_idx < self.double_blocks_num - 1:
                self.offload_manager.prefetch_weights(block_idx + 1, weights.double_blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                infer_module_out.img, infer_module_out.txt = self.infer_double_block(self.offload_manager.cuda_buffers[0], infer_module_out)
            self.offload_manager.swap_blocks()
lightx2v/models/networks/hunyuan_video/infer/transformer_infer.py
View file @
74eeb429
...
...
@@ -234,16 +234,11 @@ class HunyuanVideo15TransformerInfer(BaseTransformerInfer):
                attention_module=weights.self_attention,
                seq_p_group=self.seq_p_group,
                use_fp8_comm=self.seq_p_fp8_comm,
                model_cls=self.config["model_cls"],
            )
        else:
            attn_out = weights.self_attention.apply(
-                q=query,
-                k=key,
-                v=value,
-                cu_seqlens_q=cu_seqlens_qkv,
-                cu_seqlens_kv=cu_seqlens_qkv,
-                max_seqlen_q=seqlen,
-                max_seqlen_kv=seqlen,
+                q=query, k=key, v=value, cu_seqlens_q=cu_seqlens_qkv, cu_seqlens_kv=cu_seqlens_qkv, max_seqlen_q=seqlen, max_seqlen_kv=seqlen, model_cls=self.config["model_cls"]
            )
        img_attn, txt_attn = attn_out[:img_seqlen], attn_out[img_seqlen:]
...
lightx2v/models/networks/hunyuan_video/model.py
View file @
74eeb429
...
...
@@ -32,7 +32,8 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
        self.seq_p_group = None
        self.cpu_offload = self.config.get("cpu_offload", False)
        self.offload_granularity = self.config.get("offload_granularity", "block")
-        self.remove_keys = ["byt5_in", "vision_in"]
+        self.remove_keys = []
+        self.remove_keys.extend(["byt5_in", "vision_in"])
        self.dit_quantized = self.config.get("dit_quantized", False)
        if self.dit_quantized:
            assert self.config.get("dit_quant_scheme", "Default") in [
...
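The model code above keys its behavior off a handful of config entries (`cpu_offload`, `offload_granularity`, `dit_quantized`, `dit_quant_scheme`). A hypothetical config fragment wiring them together might look like the following; the accepted values are not spelled out in this diff, so treat it as a guess at the shape rather than a reference:

config = {
    "cpu_offload": True,             # keep DiT weights on CPU and stream them in
    "offload_granularity": "block",  # offload per transformer block (see the buffers below)
    "dit_quantized": False,          # set True together with a dit_quant_scheme
    "dit_quant_scheme": "Default",
}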
@@ -98,7 +99,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
        self.transformer_infer = self.transformer_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)

    def set_scheduler(self, scheduler):
        self.scheduler = scheduler
...
@@ -176,7 +177,7 @@ class HunyuanVideo15Model(CompiledMethodsMixin):
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.config["parallel"]:
+        if self.device.type != "cpu" and dist.is_initialized():
            device = dist.get_rank()
        else:
            device = str(self.device)
...
lightx2v/models/networks/hunyuan_video/weights/transformer_weights.py
View file @
74eeb429
...
...
@@ -24,7 +24,7 @@ class HunyuanVideo15TransformerWeights(WeightModule):
        if config["cpu_offload"]:
            if config.get("offload_granularity", "block") == "block":
                self.offload_blocks_num = 2
-                self.offload_block_buffers = WeightModuleList(
+                self.offload_block_cuda_buffers = WeightModuleList(
                    [
                        MMDoubleStreamBlock(
                            i,
...
@@ -36,8 +36,8 @@ class HunyuanVideo15TransformerWeights(WeightModule):
                        for i in range(self.offload_blocks_num)
                    ]
                )
-                self.add_module("offload_block_buffers", self.offload_block_buffers)
-                self.offload_phase_buffers = None
+                self.add_module("offload_block_cuda_buffers", self.offload_block_cuda_buffers)
+                self.offload_phase_cuda_buffers = None

    def non_block_weights_to_cuda(self):
        self.final_layer.to_cuda()
...
@@ -47,23 +47,24 @@ class HunyuanVideo15TransformerWeights(WeightModule):

 class MMDoubleStreamBlock(WeightModule):
-    def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+    def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
        self.config = config
-        self.is_offload_buffer = is_offload_buffer
+        self.create_cuda_buffer = create_cuda_buffer
+        self.create_cpu_buffer = create_cpu_buffer
        self.lazy_load = False
        self.lazy_load_file = None

        self.add_module(
            "img_branch",
-            MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, is_offload_buffer),
+            MMDoubleStreamBlockImgBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
        )
        self.add_module(
            "txt_branch",
-            MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, is_offload_buffer),
+            MMDoubleStreamBlockTxtBranch(block_index, task, config, block_prefix, create_cuda_buffer, create_cpu_buffer),
        )
        attention_weights_cls = ATTN_WEIGHT_REGISTER[self.config["attn_type"]]
        self.add_module("self_attention", attention_weights_cls())
...
@@ -75,7 +76,7 @@ class MMDoubleStreamBlock(WeightModule):

 class MMDoubleStreamBlockImgBranch(WeightModule):
-    def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+    def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
...
@@ -93,7 +94,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mod.linear.weight",
                f"{block_prefix}.{self.block_index}.img_mod.linear.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -103,6 +105,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -112,7 +116,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_q.weight",
                f"{block_prefix}.{self.block_index}.img_attn_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -122,7 +127,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_k.weight",
                f"{block_prefix}.{self.block_index}.img_attn_k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -132,7 +138,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_v.weight",
                f"{block_prefix}.{self.block_index}.img_attn_v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -141,7 +148,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            "img_attn_q_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.img_attn_q_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -150,7 +158,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            "img_attn_k_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.img_attn_k_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -160,7 +169,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_attn_proj.weight",
                f"{block_prefix}.{self.block_index}.img_attn_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -170,6 +180,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -179,7 +191,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mlp.fc1.weight",
                f"{block_prefix}.{self.block_index}.img_mlp.fc1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -189,7 +202,8 @@ class MMDoubleStreamBlockImgBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.img_mlp.fc2.weight",
                f"{block_prefix}.{self.block_index}.img_mlp.fc2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -197,7 +211,7 @@ class MMDoubleStreamBlockImgBranch(WeightModule):

 class MMDoubleStreamBlockTxtBranch(WeightModule):
-    def __init__(self, block_index, task, config, block_prefix="double_blocks", is_offload_buffer=False):
+    def __init__(self, block_index, task, config, block_prefix="double_blocks", create_cuda_buffer=False, create_cpu_buffer=False):
        super().__init__()
        self.block_index = block_index
        self.task = task
...
@@ -215,7 +229,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mod.linear.weight",
                f"{block_prefix}.{self.block_index}.txt_mod.linear.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -225,6 +240,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -234,7 +251,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_q.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_q.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -244,7 +262,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_k.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_k.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -254,7 +273,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_v.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_v.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -263,7 +283,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            "txt_attn_q_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.txt_attn_q_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -272,7 +293,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            "txt_attn_k_norm",
            RMS_WEIGHT_REGISTER[self.rms_type](
                f"{block_prefix}.{self.block_index}.txt_attn_k_norm.weight",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -282,7 +304,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_attn_proj.weight",
                f"{block_prefix}.{self.block_index}.txt_attn_proj.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -292,6 +315,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -301,7 +326,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mlp.fc1.weight",
                f"{block_prefix}.{self.block_index}.txt_mlp.fc1.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -311,7 +337,8 @@ class MMDoubleStreamBlockTxtBranch(WeightModule):
            MM_WEIGHT_REGISTER[self.mm_type](
                f"{block_prefix}.{self.block_index}.txt_mlp.fc2.weight",
                f"{block_prefix}.{self.block_index}.txt_mlp.fc2.bias",
-                is_offload_buffer,
+                create_cuda_buffer,
+                create_cpu_buffer,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -333,6 +360,8 @@ class FinalLayerWeights(WeightModule):
            MM_WEIGHT_REGISTER["Default"](
                "final_layer.adaLN_modulation.1.weight",
                "final_layer.adaLN_modulation.1.bias",
+                False,
+                False,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -342,6 +371,8 @@ class FinalLayerWeights(WeightModule):
            MM_WEIGHT_REGISTER["Default"](
                "final_layer.linear.weight",
                "final_layer.linear.bias",
+                False,
+                False,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
@@ -351,6 +382,8 @@ class FinalLayerWeights(WeightModule):
            LN_WEIGHT_REGISTER[self.ln_type](
                None,
                None,
+                False,
+                False,
                self.lazy_load,
                self.lazy_load_file,
            ),
...
lightx2v/models/networks/qwen_image/infer/offload/transformer_infer.py
View file @
74eeb429
...
...
@@ -2,6 +2,9 @@ import torch
 from lightx2v.common.offload.manager import WeightAsyncStreamManager
 from lightx2v.models.networks.qwen_image.infer.transformer_infer import QwenImageTransformerInfer
+from lightx2v_platform.base.global_var import AI_DEVICE
+
+torch_device_module = getattr(torch, AI_DEVICE)


 class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
...
@@ -37,7 +40,7 @@ class QwenImageOffloadTransformerInfer(QwenImageTransformerInfer):
            if block_idx < self.num_blocks - 1:
                self.offload_manager.prefetch_weights(block_idx + 1, block_weights.blocks)
-            with torch.cuda.stream(self.offload_manager.compute_stream):
+            with torch_device_module.stream(self.offload_manager.compute_stream):
                encoder_hidden_states, hidden_states = self.infer_block(
                    block_weight=self.offload_manager.cuda_buffers[0], hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb, image_rotary_emb=image_rotary_emb
                )
...
lightx2v/models/networks/qwen_image/model.py
View file @
74eeb429
...
...
@@ -8,7 +8,6 @@ from safetensors import safe_open
 from lightx2v.utils.envs import *
 from lightx2v.utils.utils import *
-from lightx2v_platform.base.global_var import AI_DEVICE
 from .infer.offload.transformer_infer import QwenImageOffloadTransformerInfer
 from .infer.post_infer import QwenImagePostInfer
...
@@ -125,7 +124,7 @@ class QwenImageTransformerModel:
    def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
        remove_keys = self.remove_keys if hasattr(self, "remove_keys") else []
-        if self.config["parallel"]:
+        if self.device.type != "cpu" and dist.is_initialized():
            device = dist.get_rank()
        else:
            device = str(self.device)
...
@@ -284,7 +283,7 @@ class QwenImageTransformerModel:
        self.pre_infer = self.pre_infer_class(self.config)
        self.post_infer = self.post_infer_class(self.config)
        if hasattr(self.transformer_infer, "offload_manager"):
-            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_buffers, self.transformer_weights.offload_phase_buffers)
+            self.transformer_infer.offload_manager.init_cuda_buffer(self.transformer_weights.offload_block_cuda_buffers, self.transformer_weights.offload_phase_cuda_buffers)

    def to_cpu(self):
        self.pre_weight.to_cpu()
...