xuwx1 / LightX2V · Commit 6b32b743

Support loading bf16 weights and converting them to fp32

Authored Aug 06, 2025 by gushiqiao
Parent: 978e3b32
Showing 3 changed files with 41 additions and 7 deletions:

- lightx2v/common/ops/mm/mm_weight.py (+34, −0)
- lightx2v/models/networks/wan/causvid_model.py (+1, −1)
- lightx2v/models/networks/wan/model.py (+6, −6)
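As the title says, the point of the commit is to accept checkpoints whose tensors are stored in bf16 and widen them to fp32 at load time. The cast itself is an ordinary dtype conversion; a minimal sketch, where the file and tensor names are illustrative, not taken from the repository:

```python
import torch
from safetensors import safe_open

# Minimal sketch: read a (possibly bf16) tensor and widen it to fp32.
# "model.safetensors" and "some.weight" are illustrative names.
with safe_open("model.safetensors", framework="pt") as f:
    t = f.get_tensor("some.weight")  # stored dtype, e.g. torch.bfloat16
    if t.dtype == torch.bfloat16:
        t = t.to(torch.float32)      # widening cast; exact for every bf16 value
```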
lightx2v/common/ops/mm/mm_weight.py

```diff
@@ -37,6 +37,10 @@ try:
 except ImportError:
     gguf = None
 
+try:
+    import marlin_cuda_quant
+except ModuleNotFoundError:
+    marlin_cuda_quant = None
 
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
@@ -683,6 +687,36 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
         super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
 
+
+@MM_WEIGHT_REGISTER("W-int4-group128-sym-Marlin")
+class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
+    """
+    Name: "W-int4-group128-sym-Marlin"
+
+    Quant int4 x FP16:
+        Weight: int4 per-group sym
+        Kernel: Marlin
+    """
+
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
+        self.load_func = self.load_quantized
+
+    def load(self, weight_dict):
+        assert not self.lazy_load
+        self.load_func(weight_dict)
+        self.workspace = weight_dict[f"{self.weight_name}_workspace"]
+        if self.bias_name is not None:
+            self.bias = weight_dict[self.bias_name]
+            self.pinned_bias = torch.empty(self.bias.shape, pin_memory=True, dtype=self.bias.dtype)
+        else:
+            self.bias = None
+
+    def apply(self, input_tensor):
+        output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
+        marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
+        if hasattr(self, "bias") and self.bias is not None:
+            output_tensor.add_(self.bias)
+        return output_tensor
+
+
 if __name__ == "__main__":
     weight_dict = {
         ...
```
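Two patterns from this file are worth calling out: the CUDA kernel is imported under try/except so CPU-only installs degrade gracefully, and implementations are registered under a string key describing the quant scheme. A generic sketch of the combination follows; REGISTRY and register are illustrative stand-ins, not LightX2V's actual MM_WEIGHT_REGISTER implementation.

```python
# Sketch of the optional-dependency + string-keyed registry pattern above.
# REGISTRY/register are illustrative; the project uses MM_WEIGHT_REGISTER.
try:
    import marlin_cuda_quant  # optional CUDA extension
except ModuleNotFoundError:
    marlin_cuda_quant = None

REGISTRY = {}

def register(name):
    def decorator(cls):
        REGISTRY[name] = cls
        return cls
    return decorator

@register("W-int4-group128-sym-Marlin")
class MarlinWeight:
    def apply(self, input_tensor):
        if marlin_cuda_quant is None:
            raise RuntimeError("marlin_cuda_quant is not available on this install")
        # ... dispatch to marlin_cuda_quant.mul(...) as in the real class ...

handler_cls = REGISTRY["W-int4-group128-sym-Marlin"]  # lookup by quant-scheme name
```

Deferring the failure from import time to apply() means the registry can always be populated, and only configs that actually select the Marlin path require the extension.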
lightx2v/models/networks/wan/causvid_model.py

```diff
@@ -37,7 +37,7 @@ class WanCausVidModel(WanModel):
         if os.path.exists(safetensors_path):
             with safe_open(safetensors_path, framework="pt") as f:
                 weight_dict = {
-                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()
+                    key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()
                 }
             return weight_dict
```
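The single changed line decides, per tensor key, whether to cast to the global compute dtype or to the sensitive dtype; previously the else branch kept whatever dtype the checkpoint was saved in, which broke when checkpoints shipped in bf16. A standalone sketch of the selection logic, with stand-in getters (the real GET_DTYPE/GET_SENSITIVE_DTYPE live elsewhere in LightX2V, and the sensitive_layer substrings here are assumptions):

```python
import torch
from safetensors import safe_open

def GET_DTYPE():             # stand-in for the project's global dtype getter
    return torch.bfloat16

def GET_SENSITIVE_DTYPE():   # stand-in: precision-sensitive layers widened to fp32
    return torch.float32

def load_weight_dict(path, device, unified_dtype, sensitive_layer=("norm", "embedding")):
    # Mirrors the changed dict comprehension: sensitive keys now get an explicit
    # cast instead of keeping the checkpoint's stored dtype.
    with safe_open(path, framework="pt") as f:
        return {
            key: (
                f.get_tensor(key).to(GET_DTYPE())
                if unified_dtype or all(s not in key for s in sensitive_layer)
                else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())
            ).pin_memory().to(device)
            for key in f.keys()
        }
```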
lightx2v/models/networks/wan/model.py

```diff
 import os
+import json
 import torch
 import torch.distributed as dist
 from loguru import logger
@@ -103,7 +103,7 @@ class WanModel:
     def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
         with safe_open(file_path, framework="pt") as f:
-            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key)).pin_memory().to(self.device) for key in f.keys()}
+            return {key: (f.get_tensor(key).to(GET_DTYPE()) if unified_dtype or all(s not in key for s in sensitive_layer) else f.get_tensor(key).to(GET_SENSITIVE_DTYPE())).pin_memory().to(self.device) for key in f.keys()}
 
     def _load_ckpt(self, unified_dtype, sensitive_layer):
         safetensors_path = find_hf_model_path(self.config, self.model_path, "dit_original_ckpt", subdir="original")
@@ -134,11 +134,11 @@ class WanModel:
         with safe_open(safetensor_path, framework="pt") as f:
             logger.info(f"Loading weights from {safetensor_path}")
             for k in f.keys():
-                if f.get_tensor(k).dtype == torch.float:
+                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
                         weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
-                        weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                        weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
                     weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
@@ -152,11 +152,11 @@ class WanModel:
         safetensor_path = os.path.join(lazy_load_model_path, "non_block.safetensors")
         with safe_open(safetensor_path, framework="pt", device="cpu") as f:
             for k in f.keys():
-                if f.get_tensor(k).dtype == torch.float:
+                if f.get_tensor(k).dtype in [torch.float16, torch.bfloat16, torch.float]:
                     if unified_dtype or all(s not in k for s in sensitive_layer):
                         pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_DTYPE()).to(self.device)
                     else:
-                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
+                        pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(GET_SENSITIVE_DTYPE()).to(self.device)
                 else:
                     pre_post_weight_dict[k] = f.get_tensor(k).pin_memory().to(self.device)
```
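Both loops in this file get the same two fixes: the guard that used to match only torch.float now accepts fp16 and bf16 as well, and the sensitive-layer branch casts through GET_SENSITIVE_DTYPE() rather than keeping the stored dtype. A sketch of the resulting control flow, reusing the stand-in getters from the previous sketch:

```python
def load_ckpt_tensors(f, device, unified_dtype, sensitive_layer):
    # f is an open safetensors handle; GET_DTYPE/GET_SENSITIVE_DTYPE as sketched above.
    weight_dict = {}
    for k in f.keys():
        t = f.get_tensor(k)
        if t.dtype in (torch.float16, torch.bfloat16, torch.float):
            if unified_dtype or all(s not in k for s in sensitive_layer):
                weight_dict[k] = t.pin_memory().to(GET_DTYPE()).to(device)
            else:
                # bf16 checkpoints: sensitive layers are widened (e.g. to fp32)
                weight_dict[k] = t.pin_memory().to(GET_SENSITIVE_DTYPE()).to(device)
        else:
            # non-float tensors (e.g. quantized integer data) keep their stored dtype
            weight_dict[k] = t.pin_memory().to(device)
    return weight_dict
```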