Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
53c0d05c
Commit
53c0d05c
authored
May 09, 2025
by
helloyongyang
Browse files
update save/load of mm weights
parent
78640ad0
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
22 deletions
+17
-22
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+17
-22
No files found.
lightx2v/common/ops/mm/mm_weight.py
View file @
53c0d05c
...
@@ -49,14 +49,6 @@ class MMWeightTemplate(metaclass=ABCMeta):
...
@@ -49,14 +49,6 @@ class MMWeightTemplate(metaclass=ABCMeta):
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
self
.
bias
=
self
.
bias
.
cuda
(
non_blocking
=
non_blocking
)
self
.
bias
=
self
.
bias
.
cuda
(
non_blocking
=
non_blocking
)
def state_dict(self, destination=None):
    """Export the weight (and bias, when present) into a plain dict.

    Values are CPU, detached clones of the live tensors, so the returned
    dict can be serialized without keeping GPU memory or autograd graph
    references alive.

    Args:
        destination: optional dict to populate; a new one is created when
            None (safe default — avoids the mutable-default pitfall).

    Returns:
        The populated dict, keyed by ``self.weight_name`` /
        ``self.bias_name``.
    """
    if destination is None:
        destination = {}
    # Clone after .cpu().detach() so the stored tensor shares no storage
    # with the on-device copy.
    destination[self.weight_name] = self.weight.cpu().detach().clone()
    if self.bias is not None:
        destination[self.bias_name] = self.bias.cpu().detach().clone()
    return destination
    # NOTE(review): reconstructed from a commit-diff page; this appears to be
    # the version REMOVED from MMWeightTemplate by this commit (subclasses now
    # override it) — confirm against the full file before relying on it.
@
MM_WEIGHT_REGISTER
(
"Default"
)
@
MM_WEIGHT_REGISTER
(
"Default"
)
class
MMWeight
(
MMWeightTemplate
):
class
MMWeight
(
MMWeightTemplate
):
...
@@ -64,12 +56,8 @@ class MMWeight(MMWeightTemplate):
...
@@ -64,12 +56,8 @@ class MMWeight(MMWeightTemplate):
super
().
__init__
(
weight_name
,
bias_name
)
super
().
__init__
(
weight_name
,
bias_name
)
def load(self, weight_dict):
    """Load this layer's weight/bias from ``weight_dict`` onto the GPU.

    Two layouts are handled:
      * save/auto-quant/default path: the checkpoint stores the weight as
        produced by training and it is transposed (``.t()``) before the
        CUDA upload, matching what ``apply`` expects
        (``shape = (input.shape[0], self.weight.shape[1])``).
      * otherwise: the stored weight is uploaded as-is — presumably it was
        already saved pre-transposed by ``state_dict`` — TODO confirm
        against the save path.

    Args:
        weight_dict: mapping from parameter name to tensor
            (``self.weight_name`` required; ``self.bias_name`` optional).
    """
    # NOTE(review): reconstructed from a side-by-side diff view; duplicated
    # context lines were de-duplicated. Verify against the real file.
    if GET_RUNNING_FLAG() == "save_naive_quant" or self.config.get("weight_auto_quant", False) or self.config.get("mm_type", "Default") == "Default":
        self.weight = weight_dict[self.weight_name].t().cuda()
        # Bias is optional: only upload when a bias name was configured.
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
    else:
        self.weight = weight_dict[self.weight_name].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
def
apply
(
self
,
input_tensor
):
def
apply
(
self
,
input_tensor
):
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
shape
=
(
input_tensor
.
shape
[
0
],
self
.
weight
.
shape
[
1
])
...
@@ -80,6 +68,14 @@ class MMWeight(MMWeightTemplate):
...
@@ -80,6 +68,14 @@ class MMWeight(MMWeightTemplate):
return
torch
.
mm
(
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
return
torch
.
mm
(
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
return
torch
.
addmm
(
self
.
bias
,
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
return
torch
.
addmm
(
self
.
bias
,
input_tensor
,
self
.
weight
,
out
=
output_tensor
)
def state_dict(self, destination=None):
    """Export weight/bias as CPU tensors, undoing the load-time transpose.

    The in-memory weight is kept transposed for ``torch.mm``/``torch.addmm``,
    so it is transposed back (``.t()``) and made ``.contiguous()`` before
    saving, keeping the on-disk layout identical to the original checkpoint.

    Args:
        destination: optional dict to populate; created when None.

    Returns:
        The populated dict, keyed by ``self.weight_name`` /
        ``self.bias_name``.
    """
    if destination is None:
        destination = {}
    # .t() restores the checkpoint layout; .contiguous() is required because
    # .t() only returns a view with swapped strides.
    destination[self.weight_name] = self.weight.cpu().detach().clone().t().contiguous()
    if self.bias is not None:
        destination[self.bias_name] = self.bias.cpu().detach().clone()
    return destination
    # NOTE(review): reconstructed from a commit-diff page — confirm placement
    # (this looks like the new MMWeight override added by this commit).
@
MM_WEIGHT_REGISTER
(
"Default-Force-FP32"
)
@
MM_WEIGHT_REGISTER
(
"Default-Force-FP32"
)
class
MMWeightForceFP32
(
MMWeight
):
class
MMWeightForceFP32
(
MMWeight
):
...
@@ -106,6 +102,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -106,6 +102,8 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def load(self, weight_dict):
    """Delegate loading to the configured loader, then fix weight layout.

    ``self.load_func`` is assigned elsewhere (not visible in this diff) and
    performs the actual quantized/unquantized load; afterwards the weight is
    transposed when ``self.weight_need_transpose`` is set — presumably for
    kernels that expect the (in, out) layout — TODO confirm against the
    concrete mm_type implementations.

    Args:
        weight_dict: mapping from parameter name to tensor.
    """
    self.load_func(weight_dict)
    if self.weight_need_transpose:
        # .t() returns a non-contiguous view; downstream kernels apparently
        # accept that here (no .contiguous() in the original).
        self.weight = self.weight.t()
def
load_quantized
(
self
,
weight_dict
):
def
load_quantized
(
self
,
weight_dict
):
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
()
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
()
...
@@ -118,8 +116,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -118,8 +116,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
float8_e4m3fn
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
float8_e4m3fn
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
if
self
.
weight_need_transpose
:
self
.
weight
=
self
.
weight
.
t
()
else
:
else
:
self
.
load_quantized
(
weight_dict
)
self
.
load_quantized
(
weight_dict
)
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
...
@@ -131,8 +127,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -131,8 +127,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
,
_
=
w_quantizer
.
real_quant_tensor
(
self
.
weight
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
int8
)
self
.
weight
=
self
.
weight
.
to
(
torch
.
int8
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
self
.
weight_scale
=
self
.
weight_scale
.
to
(
torch
.
float32
)
if
self
.
weight_need_transpose
:
self
.
weight
=
self
.
weight
.
t
()
else
:
else
:
self
.
load_quantized
(
weight_dict
)
self
.
load_quantized
(
weight_dict
)
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
...
@@ -141,8 +135,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -141,8 +135,6 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
GET_RUNNING_FLAG
()
==
"save_naive_quant"
or
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
GET_RUNNING_FLAG
()
==
"save_naive_quant"
or
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
()
self
.
weight
=
weight_dict
[
self
.
weight_name
].
cuda
()
self
.
weight
,
self
.
weight_scale
=
self
.
per_block_cast_to_fp8
(
self
.
weight
)
self
.
weight
,
self
.
weight_scale
=
self
.
per_block_cast_to_fp8
(
self
.
weight
)
if
self
.
weight_need_transpose
:
self
.
weight
=
self
.
weight
.
t
()
else
:
else
:
self
.
load_quantized
(
weight_dict
)
self
.
load_quantized
(
weight_dict
)
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
self
.
bias
=
weight_dict
[
self
.
bias_name
].
cuda
()
if
self
.
bias_name
is
not
None
else
None
...
@@ -193,7 +185,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -193,7 +185,10 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
state_dict
(
self
,
destination
=
None
):
def
state_dict
(
self
,
destination
=
None
):
if
destination
is
None
:
if
destination
is
None
:
destination
=
{}
destination
=
{}
destination
[
self
.
weight_name
]
=
self
.
weight
.
cpu
().
detach
().
clone
()
if
self
.
weight_need_transpose
:
destination
[
self
.
weight_name
]
=
self
.
weight
.
cpu
().
detach
().
clone
().
t
().
contiguous
()
else
:
destination
[
self
.
weight_name
]
=
self
.
weight
.
cpu
().
detach
().
clone
().
contiguous
()
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
destination
[
self
.
bias_name
]
=
self
.
bias
.
cpu
().
detach
().
clone
()
destination
[
self
.
bias_name
]
=
self
.
bias
.
cpu
().
detach
().
clone
()
if
hasattr
(
self
,
"weight_scale"
):
if
hasattr
(
self
,
"weight_scale"
):
...
@@ -478,7 +473,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
...
@@ -478,7 +473,7 @@ class MMWeightWint8channelAint8channeldynamicSglActVllm(MMWeightQuantTemplate):
if
__name__
==
"__main__"
:
if
__name__
==
"__main__"
:
weight_dict
=
{
weight_dict
=
{
"xx.weight"
:
torch
.
randn
(
8192
,
4096
).
to
(
torch
.
float8_e4m3fn
)
.
t
()
,
"xx.weight"
:
torch
.
randn
(
8192
,
4096
).
to
(
torch
.
float8_e4m3fn
),
"xx.bias"
:
torch
.
randn
(
8192
).
to
(
torch
.
bfloat16
),
"xx.bias"
:
torch
.
randn
(
8192
).
to
(
torch
.
bfloat16
),
"xx.weight_scale"
:
torch
.
randn
(
8192
,
1
).
to
(
torch
.
float32
),
"xx.weight_scale"
:
torch
.
randn
(
8192
,
1
).
to
(
torch
.
float32
),
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment