Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
xuwx1
LightX2V
Commits
f085ede3
"...resnet50_tensorflow.git" did not exist on "05a79f5a2227ff89a67c6ebd862f4818e6301f97"
Commit
f085ede3
authored
Sep 16, 2025
by
gushiqiao
Committed by
GitHub
Sep 16, 2025
Browse files
[Fix] Fix high peak memory bug (#313)
parent
6d9e6c0a
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
61 additions
and
43 deletions
+61
-43
lightx2v/common/ops/mm/mm_weight.py
lightx2v/common/ops/mm/mm_weight.py
+61
-43
No files found.
lightx2v/common/ops/mm/mm_weight.py
View file @
f085ede3
...
@@ -84,19 +84,25 @@ class MMWeight(MMWeightTemplate):
...
@@ -84,19 +84,25 @@ class MMWeight(MMWeightTemplate):
def
load
(
self
,
weight_dict
):
def
load
(
self
,
weight_dict
):
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
if
device
.
type
==
"cuda"
:
self
.
weight
=
weight_dict
[
self
.
weight_name
].
t
()
if
self
.
bias_name
is
not
None
:
self
.
bias
=
weight_dict
[
self
.
bias_name
]
elif
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
t
().
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
weight
.
copy_
(
weight_dict
[
self
.
weight_name
].
t
())
weight_shape
=
weight_dict
[
self
.
weight_name
].
t
().
shape
if
self
.
bias_name
is
not
None
:
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
self
.
weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
).
to
(
device
)
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
weight
=
self
.
weight
.
copy_
(
weight_dict
[
self
.
weight_name
].
t
())
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
if
self
.
bias_name
is
not
None
:
else
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
self
.
bias
=
None
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
).
to
(
device
)
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
self
.
bias
=
None
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
def
_calculate_size
(
self
):
def
_calculate_size
(
self
):
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
...
@@ -149,7 +155,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -149,7 +155,7 @@ class MMWeightQuantTemplate(MMWeightTemplate):
# weight load functions
# weight load functions
# =========================
# =========================
def
load_from_disk
(
self
):
def
load_from_disk
(
self
):
# Need Rewrite
if
not
torch
.
_dynamo
.
is_compiling
():
if
not
torch
.
_dynamo
.
is_compiling
():
self
.
weight
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
pin_memory
()
self
.
weight
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_name
).
pin_memory
()
self
.
weight_scale
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
).
float
().
pin_memory
()
self
.
weight_scale
=
self
.
lazy_load_file
.
get_tensor
(
self
.
weight_scale_name
).
float
().
pin_memory
()
...
@@ -180,28 +186,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -180,28 +186,25 @@ class MMWeightQuantTemplate(MMWeightTemplate):
def
_calculate_size
(
self
):
def
_calculate_size
(
self
):
if
self
.
bias
is
not
None
:
if
self
.
bias
is
not
None
:
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
+
self
.
bias
.
numel
()
*
self
.
bias
.
element_size
()
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
+
self
.
bias
.
numel
()
*
self
.
bias
.
element_size
()
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
return
self
.
weight
.
numel
()
*
self
.
weight
.
element_size
()
+
self
.
weight_scale
.
numel
()
*
self
.
weight_scale
.
element_size
()
def
load_quantized
(
self
,
weight_dict
):
def
load_quantized
(
self
,
weight_dict
):
device
=
weight_dict
[
self
.
weight_name
].
device
device
=
weight_dict
[
self
.
weight_name
].
device
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
if
device
.
type
==
"cuda"
:
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
self
.
weight
=
weight_dict
[
self
.
weight_name
]
self
.
weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
).
to
(
device
)
self
.
weight_scale
=
weight_dict
[
self
.
weight_scale_name
].
float
()
self
.
weight
=
self
.
weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
elif
device
.
type
==
"cpu"
:
weight_shape
=
weight_dict
[
self
.
weight_name
].
shape
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
weight_dtype
=
weight_dict
[
self
.
weight_name
].
dtype
weight_scale_dtype
=
torch
.
float
self
.
weight
=
torch
.
empty
(
weight_shape
,
pin_memory
=
True
,
dtype
=
weight_dtype
)
self
.
weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
).
to
(
device
)
self
.
weight
.
copy_
(
weight_dict
[
self
.
weight_name
])
self
.
weight_scale
=
self
.
weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
weight_scale_shape
=
weight_dict
[
self
.
weight_scale_name
].
shape
if
self
.
bias_name
is
not
None
:
weight_scale_dtype
=
torch
.
float
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
self
.
weight_scale
=
torch
.
empty
(
weight_scale_shape
,
pin_memory
=
True
,
dtype
=
weight_scale_dtype
)
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
weight_scale
.
copy_
(
weight_dict
[
self
.
weight_scale_name
])
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
).
to
(
device
)
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
self
.
bias
=
None
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
def
load_fp8_perchannel_sym
(
self
,
weight_dict
):
def
load_fp8_perchannel_sym
(
self
,
weight_dict
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
if
self
.
config
.
get
(
"weight_auto_quant"
,
False
):
...
@@ -215,10 +218,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -215,10 +218,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
device
=
weight_dict
[
self
.
bias_name
].
device
device
=
weight_dict
[
self
.
bias_name
].
device
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
if
device
.
type
==
"cuda"
:
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
weight_dict
[
self
.
bias_name
]
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
).
to
(
device
)
elif
device
.
type
==
"cpu"
:
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -234,10 +242,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -234,10 +242,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
device
=
weight_dict
[
self
.
bias_name
].
device
device
=
weight_dict
[
self
.
bias_name
].
device
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
if
device
.
type
==
"cuda"
:
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
weight_dict
[
self
.
bias_name
]
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
).
to
(
device
)
elif
device
.
type
==
"cpu"
:
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -250,10 +263,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
...
@@ -250,10 +263,15 @@ class MMWeightQuantTemplate(MMWeightTemplate):
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
device
=
weight_dict
[
self
.
bias_name
].
device
device
=
weight_dict
[
self
.
bias_name
].
device
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
if
device
.
type
==
"cuda"
:
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
weight_dict
[
self
.
bias_name
]
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
).
to
(
device
)
elif
device
.
type
==
"cpu"
:
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
raise
ValueError
(
f
"Unsupported device type:
{
device
.
type
}
, only 'cpu' and 'cuda' are supported"
)
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
@@ -735,8 +753,8 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
...
@@ -735,8 +753,8 @@ class MMWeightWint4group128Marlin(MMWeightQuantTemplate):
if
self
.
bias_name
is
not
None
:
if
self
.
bias_name
is
not
None
:
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_shape
=
weight_dict
[
self
.
bias_name
].
shape
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
bias_dtype
=
weight_dict
[
self
.
bias_name
].
dtype
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
.
to
(
device
)
self
.
bias
=
torch
.
empty
(
bias_shape
,
pin_memory
=
True
,
dtype
=
bias_dtype
)
self
.
bias
=
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
self
.
bias
.
copy_
(
weight_dict
[
self
.
bias_name
])
else
:
else
:
self
.
bias
=
None
self
.
bias
=
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment