Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
6b32b743
"examples/model_compress/vscode:/vscode.git/clone" did not exist on "6e643b00a22ae251ca120936fda34420f8d88fc5"
Commit
6b32b743
authored
Aug 06, 2025
by
gushiqiao
Browse files
Support loading bf16 weights and converting them to fp32
parent
978e3b32
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
41 additions
and
7 deletions
+41
-7
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+34
-0
lightx2v/models/networks/wan/causvid_model.py
lightx2v/models/networks/wan/causvid_model.py
+1
-1
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+6
-6
No files found.
lightx2v/common/ops/mm/mm_weight.py
View file @
6b32b743
...
...
@@ -37,6 +37,10 @@ try:
except
ImportError
:
gguf
=
None
try
:
import
marlin_cuda_quant
except
ModuleNotFoundError
:
marlin_cuda_quant
=
None
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
lazy_load
=
False
,
lazy_load_file
=
None
):
...
...
@@ -683,7 +687,37 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
    """Q4_K GGUF matmul weight.

    Pure delegation: all state setup (names, lazy-load bookkeeping) is
    done by the MMWeightGGUFTemplate base class; this subclass only
    selects the Q4_K dequant/kernel path declared elsewhere in the class.

    Args:
        weight_name: key of the packed weight tensor in the checkpoint dict.
        bias_name: key of the bias tensor, or None when the layer has no bias.
        lazy_load: when True, defer tensor materialization to first use.
        lazy_load_file: open checkpoint handle used for lazy loading
            (presumably a safetensors handle — confirm against base class).
    """
    super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
    """
    Name: "W-int4-group128-sym-Marlin"

    Quant int4 x FP16:
        Weight: int4 pergroup sym
        Kernel: Marlin
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        """Marlin int4 group-128 symmetric matmul weight.

        Common setup is delegated to the template; the eager load path is
        bound to ``load_quantized`` (weights arrive pre-packed for Marlin).
        """
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_quantized

    def load(self, weight_dict):
        """Load packed weight, Marlin workspace, and optional bias.

        Only eager loading is supported on this path (asserted below).
        """
        assert not self.lazy_load
        self.load_func(weight_dict)
        # Marlin requires a per-weight scratch tensor; it must be present in
        # the checkpoint dict under "<weight_name>_workspace".
        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
        if self.bias_name is not None:
            self.bias = weight_dict[self.bias_name]
            # Pinned host buffer matching the bias — presumably a staging
            # buffer for async host-to-device copies; TODO confirm usage.
            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
        else:
            self.bias = None

    def apply(self, input_tensor):
        """Compute input @ W (int4, Marlin kernel) + bias, in-place bias add.

        Returns a tensor with the trailing dim replaced by the output
        feature count.
        """
        # The output feature count is taken from weight_scale's second axis;
        # the packed int4 weight's own shape is not usable for this.
        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
        # Scales are passed as fp16 as required by the kernel. The four -1
        # trailing args are presumably "auto" tile/partition sizes — confirm
        # against the marlin_cuda_quant.mul signature.
        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
        # hasattr guard: tolerate apply() being reached before load() set bias.
        if hasattr(self, "bias") and self.bias is not None:
            output_tensor.add_(self.bias)
        return output_tensor
if
__name__
==
"__main__"
:
weight_dict
=
{
"xx.weight"
:
torch
.
randn
(
8192
,
4096
).
to
(
torch
.
float8_e4m3fn
),
...
...
lightx2v/models/networks/wan/causvid_model.py
View file @
6b32b743
...
...
@@ -37,7 +37,7 @@ class WanCausVidModel(WanModel):
if os.path.exists(safetensors_path):
    with safe_open(safetensors_path, framework="pt") as f:
        # Dtype policy (this commit): cast every tensor to GET_DTYPE()
        # unless the key belongs to a dtype-sensitive layer and
        # unified_dtype is off, in which case GET_SENSITIVE_DTYPE() is
        # used instead. Each tensor is staged through pinned host memory
        # before the copy to self.device.
        weight_dict = {
            key: (
                f.get_tensor(key).to(GET_DTYPE())
                if unified_dtype or all(s not in key for s in sensitive_layer)
                else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())
            )
            .pin_memory()
            .to(self.device)
            for key in f.keys()
        }
    return weight_dict
...
...
lightx2v/models/networks/wan/model.py
View file @
6b32b743
import
os
import
json
import
torch
import
torch.distributed
as
dist
from
loguru
import
logger
...
...
@@ -103,7 +103,7 @@ class WanModel:
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
    """Load one safetensors file into a {name: tensor} dict on self.device.

    Each tensor is cast to GET_DTYPE(), except tensors whose key contains
    any of the ``sensitive_layer`` substrings while ``unified_dtype`` is
    False — those keep GET_SENSITIVE_DTYPE(). Every tensor is staged
    through pinned host memory before the device copy.
    """
    with safe_open(file_path, framework="pt") as f:
        tensors = {}
        for key in f.keys():
            # Sensitive layers keep their dedicated dtype only when a
            # unified dtype has not been forced for the whole model.
            keep_sensitive = not unified_dtype and any(s in key for s in sensitive_layer)
            target_dtype = GET_SENSITIVE_DTYPE() if keep_sensitive else GET_DTYPE()
            tensors[key] = f.get_tensor(key).to(target_dtype).pin_memory().to(self.device)
        return tensors
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
safetensors_path
=
find_hf_model_path
(
self
.
config
,
self
.
model_path
,
"dit_original_ckpt"
,
subdir
=
"original"
)
...
...
@@ -134,11 +134,11 @@ class WanModel:
with safe_open(safetensor_path, framework="pt") as f:
    logger.info(f"Loading weights from {safetensor_path}")
    for k in f.keys():
        # Hoisted: the original called f.get_tensor(k) twice per key
        # (once for the dtype test, once for the copy), deserializing
        # each tensor from the file two times.
        tensor = f.get_tensor(k)
        # Floating-point weights (fp16/bf16/fp32) are re-cast per the
        # dtype policy; anything else (e.g. quantized int tensors) is
        # copied to the device unchanged.
        if tensor.dtype in [torch.float16, torch.bfloat16, torch.float]:
            if unified_dtype or all(s not in k for s in sensitive_layer):
                weight_dict[k] = tensor.pin_memory().to(GET_DTYPE()).to(self.device)
            else:
                # Sensitive layers keep their dedicated dtype.
                weight_dict[k] = tensor.pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
        else:
            weight_dict[k] = tensor.pin_memory().to(self.device)
...
@@ -152,11 +152,11 @@ class WanModel:
safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
with safe_open(safetensor_path, framework="pt", device="cpu") as f:
    for k in f.keys():
        # Hoisted: the original called f.get_tensor(k) twice per key
        # (dtype test + copy), deserializing each tensor twice.
        tensor = f.get_tensor(k)
        # Same dtype policy as the main weight loader: float tensors are
        # re-cast (unified dtype, or sensitive dtype for sensitive
        # layers); non-float tensors are copied unchanged.
        if tensor.dtype in [torch.float16, torch.bfloat16, torch.float]:
            if unified_dtype or all(s not in k for s in sensitive_layer):
                pre_post_weight_dict[k] = tensor.pin_memory().to(GET_DTYPE()).to(self.device)
            else:
                pre_post_weight_dict[k] = tensor.pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
        else:
            pre_post_weight_dict[k] = tensor.pin_memory().to(self.device)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment