Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
94bd4599
Commit
94bd4599
authored
Aug 06, 2025
by
gushiqiao
Browse files
Support loading bf16 weights and converting them to fp32
parent
6b32b743
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
5 deletions
+15
-5
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+5
-2
lightx2v/models/networks/wan/causvid_model.py
lightx2v/models/networks/wan/causvid_model.py
+4
-1
lightx2v/models/networks/wan/model.py
lightx2v/models/networks/wan/model.py
+6
-2
No files found.
lightx2v/common/ops/mm/mm_weight.py
View file @
94bd4599
...
@@ -42,6 +42,7 @@ try:
...
@@ -42,6 +42,7 @@ try:
except
ModuleNotFoundError
:
except
ModuleNotFoundError
:
marlin_cuda_quant
=
None
marlin_cuda_quant
=
None
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
class
MMWeightTemplate
(
metaclass
=
ABCMeta
):
def
__init__
(
self
,
weight_name
,
bias_name
,
lazy_load
=
False
,
lazy_load_file
=
None
):
def
__init__
(
self
,
weight_name
,
bias_name
,
lazy_load
=
False
,
lazy_load_file
=
None
):
self
.
weight_name
=
weight_name
self
.
weight_name
=
weight_name
...
@@ -687,6 +688,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
...
@@ -687,6 +688,7 @@ class MMWeightGGUFQ4K(MMWeightGGUFTemplate):
def
__init__
(
self
,
weight_name
,
bias_name
,
lazy_load
=
False
,
lazy_load_file
=
None
):
def
__init__
(
self
,
weight_name
,
bias_name
,
lazy_load
=
False
,
lazy_load_file
=
None
):
super
().
__init__
(
weight_name
,
bias_name
,
lazy_load
,
lazy_load_file
)
super
().
__init__
(
weight_name
,
bias_name
,
lazy_load
,
lazy_load_file
)
@
MM_WEIGHT_REGISTER
(
"W-int4-group128-sym-Marlin"
)
@
MM_WEIGHT_REGISTER
(
"W-int4-group128-sym-Marlin"
)
class
MMWeightWint4group128Marlin
(
MMWeightQuantTemplate
):
class
MMWeightWint4group128Marlin
(
MMWeightQuantTemplate
):
"""
"""
...
@@ -710,14 +712,15 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
...
@@ -710,14 +712,15 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
self
.
pinned_bias
=
torch
.
empty
(
self
.
bias
.
shape
,
pin_memory
=
True
,
dtype
=
self
.
bias
.
dtype
)
self
.
pinned_bias
=
torch
.
empty
(
self
.
bias
.
shape
,
pin_memory
=
True
,
dtype
=
self
.
bias
.
dtype
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
def apply(self, input_tensor):
    """Apply the Marlin int4-group128 quantized matmul to ``input_tensor``.

    Computes ``input_tensor @ dequant(self.weight)`` via the fused
    ``marlin_cuda_quant.mul`` CUDA kernel, adds ``self.bias`` in place when
    one was loaded, and returns the result.

    Args:
        input_tensor: activation tensor; the matmul contracts over its last
            dimension. Assumed to be fp16 on a CUDA device (the Marlin
            kernel is fp16-only) -- TODO confirm against caller.

    Returns:
        Tensor of shape ``input_tensor.shape[:-1] + (out_features,)`` with
        the same dtype/device as ``input_tensor``.
    """
    # Output features come from the second dim of the per-group scale
    # tensor; batch/sequence dims are carried over from the input.
    output_tensor = torch.empty(input_tensor.shape[:-1] + (self.weight_scale.shape[1],), dtype=input_tensor.dtype, device=input_tensor.device)
    # Fused dequant+GEMM. Scales are cast to fp16 as the kernel expects.
    # The four trailing -1s are kernel tuning parameters (presumably
    # thread_k / thread_n / sms / max_par, -1 = auto) -- NOTE(review):
    # confirm against the marlin_cuda_quant API.
    marlin_cuda_quant.mul(input_tensor, self.weight, output_tensor, self.weight_scale.half(), self.workspace, -1, -1, -1, -1)
    # Bias is optional: some checkpoints have none, in which case the
    # loader sets self.bias = None (or never sets the attribute).
    if hasattr(self, "bias") and self.bias is not None:
        output_tensor.add_(self.bias)
    return output_tensor
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
weight_dict
=
{
weight_dict
=
{
"xx.weight"
:
torch
.
randn
(
8192
,
4096
).
to
(
torch
.
float8_e4m3fn
),
"xx.weight"
:
torch
.
randn
(
8192
,
4096
).
to
(
torch
.
float8_e4m3fn
),
...
...
lightx2v/models/networks/wan/causvid_model.py
View file @
94bd4599
...
@@ -37,7 +37,10 @@ class WanCausVidModel(WanModel):
...
@@ -37,7 +37,10 @@ class WanCausVidModel(WanModel):
if
os
.
path
.
exists
(
safetensors_path
):
if
os
.
path
.
exists
(
safetensors_path
):
with
safe_open
(
safetensors_path
,
framework
=
"pt"
)
as
f
:
with
safe_open
(
safetensors_path
,
framework
=
"pt"
)
as
f
:
weight_dict
=
{
weight_dict
=
{
key
:
(
f
.
get_tensor
(
key
).
to
(
GET_DTYPE
())
if
unified_dtype
or
all
(
s
not
in
key
for
s
in
sensitive_layer
)
else
f
.
get_tensor
(
key
).
to
(
GET_SENSITIVE_DTYPE
())).
pin_memory
().
to
(
self
.
device
)
for
key
in
f
.
keys
()
key
:
(
f
.
get_tensor
(
key
).
to
(
GET_DTYPE
())
if
unified_dtype
or
all
(
s
not
in
key
for
s
in
sensitive_layer
)
else
f
.
get_tensor
(
key
).
to
(
GET_SENSITIVE_DTYPE
()))
.
pin_memory
()
.
to
(
self
.
device
)
for
key
in
f
.
keys
()
}
}
return
weight_dict
return
weight_dict
...
...
lightx2v/models/networks/wan/model.py
View file @
94bd4599
import
os
import
json
import
json
import
os
import
torch
import
torch
import
torch.distributed
as
dist
import
torch.distributed
as
dist
from
loguru
import
logger
from
loguru
import
logger
...
@@ -103,7 +104,10 @@ class WanModel:
...
@@ -103,7 +104,10 @@ class WanModel:
def _load_safetensor_to_dict(self, file_path, unified_dtype, sensitive_layer):
    """Load every tensor from a safetensors file onto ``self.device``.

    Each tensor is first cast on the host: to ``GET_DTYPE()`` when
    ``unified_dtype`` is set or the key matches no sensitive-layer marker,
    otherwise to ``GET_SENSITIVE_DTYPE()``.  The cast tensor is then
    pinned (page-locked host memory) and copied to ``self.device``.

    Args:
        file_path: path to a ``.safetensors`` checkpoint file.
        unified_dtype: when truthy, force ``GET_DTYPE()`` for all keys.
        sensitive_layer: iterable of substrings marking keys that must
            keep the sensitive (higher-precision) dtype.

    Returns:
        dict mapping tensor names to device-resident tensors.
    """
    tensors = {}
    with safe_open(file_path, framework="pt") as f:
        for name in f.keys():
            # A key keeps the unified dtype unless some sensitive marker
            # appears in it (and unified_dtype does not override that).
            use_unified = unified_dtype or not any(marker in name for marker in sensitive_layer)
            target_dtype = GET_DTYPE() if use_unified else GET_SENSITIVE_DTYPE()
            # Pin host memory before the H2D copy so the transfer can be
            # done from page-locked memory.
            tensors[name] = f.get_tensor(name).to(target_dtype).pin_memory().to(self.device)
    return tensors
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
def
_load_ckpt
(
self
,
unified_dtype
,
sensitive_layer
):
safetensors_path
=
find_hf_model_path
(
self
.
config
,
self
.
model_path
,
"dit_original_ckpt"
,
subdir
=
"original"
)
safetensors_path
=
find_hf_model_path
(
self
.
config
,
self
.
model_path
,
"dit_original_ckpt"
,
subdir
=
"original"
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment