xuwx1 / LightX2V · Commit e106ff67

support lightx2v-kernel and update convert.py (#413)

Unverified commit, authored Oct 30, 2025 by Bilang ZHANG and committed via GitHub on Oct 30, 2025.
Parent: 3aab9893

Showing 9 changed files with 571 additions and 4 deletions (+571 −4):
lightx2v/common/ops/mm/mm_weight.py (+367 −0)
lightx2v/models/networks/wan/model.py (+23 −1)
lightx2v/models/networks/wan/weights/transformer_weights.py (+2 −0)
lightx2v/models/runners/default_runner.py (+6 −0)
lightx2v/utils/global_paras.py (+1 −0)
lightx2v/utils/registry_factory.py (+1 −0)
tools/convert/converter.py (+38 −3)
tools/convert/quant/__init__.py (+1 −0)
tools/convert/quant/quant.py (+132 −0)
lightx2v/common/ops/mm/mm_weight.py

@@ -3,9 +3,27 @@ from abc import ABCMeta, abstractmethod
import torch

from lightx2v.utils.envs import *
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.quant_utils import FloatQuantizer, IntegerQuantizer
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER

try:
    from lightx2v_kernel.gemm import (
        cutlass_scaled_mxfp4_mm,
        cutlass_scaled_mxfp6_mxfp8_mm,
        cutlass_scaled_mxfp8_mm,
        cutlass_scaled_nvfp4_mm,
        scaled_mxfp4_quant,
        scaled_mxfp6_quant,
        scaled_mxfp8_quant,
        scaled_nvfp4_quant,
    )
except ImportError:
    scaled_nvfp4_quant, cutlass_scaled_nvfp4_mm = None, None
    scaled_mxfp4_quant, cutlass_scaled_mxfp4_mm = None, None
    scaled_mxfp6_quant, cutlass_scaled_mxfp6_mxfp8_mm = None, None
    scaled_mxfp8_quant, cutlass_scaled_mxfp8_mm = None, None

try:
    from vllm import _custom_ops as ops
except ImportError:
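The kernel imports are deliberately optional: when lightx2v-kernel is not installed, the quant and GEMM symbols fall back to None instead of failing at import time. A minimal sketch of how a caller could turn that convention into an actionable error at use time (this guard is an illustration, not part of the commit):

def require_kernel(fn, name):
    # Hypothetical helper, assuming the fallback-to-None convention above:
    # fail lazily with a clear message instead of a TypeError on None.
    if fn is None:
        raise RuntimeError(f"{name} requires the lightx2v-kernel package")
    return fn

# Usage sketch: quant = require_kernel(scaled_mxfp4_quant, "mxfp4 quantization")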
...

@@ -267,6 +285,179 @@ class MMWeightQuantTemplate(MMWeightTemplate):

        self.bias = None
        self.pin_bias = None

    def load_mxfp4(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp4_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])
                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None
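All of the new load_* paths share one pattern for CPU-resident checkpoints: each tensor is copied into a page-locked (pinned) host buffer so the later host-to-device transfer can run asynchronously. A standalone sketch of that staging step, with illustrative tensor names:

import torch

def stage_pinned(src: torch.Tensor) -> torch.Tensor:
    # A pinned host buffer lets .cuda(non_blocking=True) overlap with compute.
    pinned = torch.empty(src.shape, dtype=src.dtype, pin_memory=True)
    pinned.copy_(src)
    return pinned

cpu_weight = torch.randn(4096, 4096, dtype=torch.bfloat16)
pin_weight = stage_pinned(cpu_weight)
gpu_weight = pin_weight.cuda(non_blocking=True)  # async H2D, as in to_cuda() below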
    def load_mxfp6(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp6_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])
                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None

    def load_mxfp8(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            device = weight_dict[self.weight_name].device
            self.weight = weight_dict[self.weight_name].cuda().to(torch.bfloat16)
            self.weight, self.weight_scale = scaled_mxfp8_quant(self.weight)
            self.weight, self.weight_scale = self.weight.to(device), self.weight_scale.to(device)
        else:
            device = weight_dict[self.weight_name].device
            if device.type == "cuda":
                self.weight = weight_dict[self.weight_name]
                self.weight_scale = weight_dict[self.weight_scale_name]
            elif device.type == "cpu":
                weight_shape = weight_dict[self.weight_name].shape
                weight_dtype = weight_dict[self.weight_name].dtype
                self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
                self.pin_weight.copy_(weight_dict[self.weight_name])
                weight_scale_shape = weight_dict[self.weight_scale_name].shape
                weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
                self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
                self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
                del weight_dict[self.weight_name]
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None
    def load_nvfp4(self, weight_dict):
        device = weight_dict[self.weight_name].device
        input_absmax = weight_dict[self.weight_name.replace(".weight", ".input_absmax")]
        input_global_scale = (2688.0 / input_absmax).to(torch.float32)
        weight_global_scale = weight_dict[f"{self.weight_name}_global_scale"]
        alpha = 1.0 / (input_global_scale * weight_global_scale)
        if device.type == "cuda":
            self.weight = weight_dict[self.weight_name]
            self.weight_scale = weight_dict[self.weight_scale_name]
            self.input_global_scale = input_global_scale
            self.alpha = alpha
        elif device.type == "cpu":
            weight_shape = weight_dict[self.weight_name].shape
            weight_dtype = weight_dict[self.weight_name].dtype
            self.pin_weight = torch.empty(weight_shape, pin_memory=True, dtype=weight_dtype)
            self.pin_weight.copy_(weight_dict[self.weight_name])
            weight_scale_shape = weight_dict[self.weight_scale_name].shape
            weight_scale_dtype = weight_dict[self.weight_scale_name].dtype
            self.pin_weight_scale = torch.empty(weight_scale_shape, pin_memory=True, dtype=weight_scale_dtype)
            self.pin_weight_scale.copy_(weight_dict[self.weight_scale_name])
            input_global_scale_shape = input_global_scale.shape
            input_global_scale_dtype = input_global_scale.dtype
            self.pin_input_global_scale = torch.empty(input_global_scale_shape, pin_memory=True, dtype=input_global_scale_dtype)
            self.pin_input_global_scale.copy_(input_global_scale)
            alpha_shape = alpha.shape
            alpha_dtype = alpha.dtype
            self.pin_alpha = torch.empty(alpha_shape, pin_memory=True, dtype=alpha_dtype)
            self.pin_alpha.copy_(alpha)
            del weight_dict[self.weight_name]
        else:
            raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        if self.bias_name is not None:
            device = weight_dict[self.bias_name].device
            if device.type == "cuda":
                self.bias = weight_dict[self.bias_name]
            elif device.type == "cpu":
                bias_shape = weight_dict[self.bias_name].shape
                bias_dtype = weight_dict[self.bias_name].dtype
                self.pin_bias = torch.empty(bias_shape, pin_memory=True, dtype=bias_dtype)
                self.pin_bias.copy_(weight_dict[self.bias_name])
            else:
                raise ValueError(f"Unsupported device type: {device.type}, only 'cpu' and 'cuda' are supported")
        else:
            self.bias = None
            self.pin_bias = None
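load_nvfp4 derives both runtime scales from the calibrated activation absmax: input_global_scale = 2688.0 / absmax and alpha = 1 / (input_global_scale * weight_global_scale), where alpha rescales the low-precision GEMM output back to the real range. The constant 2688 is plausibly FP8_E4M3_MAX * FP4_E2M1_MAX = 448 * 6, i.e. the largest representable block-scale-times-value product, though the commit does not state this. A small numeric sketch with illustrative values:

import torch

FP8_E4M3_MAX, FP4_E2M1_MAX = 448.0, 6.0
assert FP8_E4M3_MAX * FP4_E2M1_MAX == 2688.0  # assumed origin of the constant

input_absmax = torch.tensor(13.5)                 # from calib.pt, per weight name
input_global_scale = (2688.0 / input_absmax).to(torch.float32)
weight_global_scale = torch.tensor(171.2)         # stored as f"{weight_name}_global_scale"
alpha = 1.0 / (input_global_scale * weight_global_scale)  # fed to cutlass_scaled_nvfp4_mm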
    def load_fp8_perblock128_sym(self, weight_dict):
        if self.config.get("weight_auto_quant", False):
            self.weight = weight_dict[self.weight_name]

...

@@ -325,6 +516,18 @@ class MMWeightQuantTemplate(MMWeightTemplate):
        input_tensor_quant, input_tensor_scale, _ = ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True)
        return input_tensor_quant, input_tensor_scale

    def act_quant_nvfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_nvfp4_quant(x, self.input_global_scale)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp4(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp4_quant(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_mxfp8(self, x):
        input_tensor_quant, input_tensor_scale = scaled_mxfp8_quant(x)
        return input_tensor_quant, input_tensor_scale

    def act_quant_fp8_perchannelgroup128_sym_deepgemm(self, x):
        assert x.dim() == 2 and x.size(1) % 128 == 0
        m, n = x.shape
...

@@ -431,6 +634,170 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp4")
class MMWeightWmxfp4Amxfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp4-A-mxfp4-dynamic

    Quant MM:
        Weight: mxfp4
        Act: mxfp4
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp4
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor
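Each of these classes registers itself under the same string that the config's dit_quant_scheme key uses, so constructing a quantized linear is a registry lookup. A hedged usage sketch; the lookup style mirrors how MM_WEIGHT_REGISTER is used elsewhere in the codebase, and the load() entry point that dispatches to self.load_func is assumed from the template:

# Assumed: the registry maps "mxfp4" to MMWeightWmxfp4Amxfp4dynamic.
mm_cls = MM_WEIGHT_REGISTER["mxfp4"]
mm = mm_cls("blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias")
mm.load(weight_dict)           # assumed to invoke self.load_func, i.e. load_mxfp4
out = mm.apply(hidden_states)  # quantize activations, then cutlass_scaled_mxfp4_mm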
@MM_WEIGHT_REGISTER("mxfp6-mxfp8")
class MMWeightWmxfp6Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp6-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp6
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp6
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp6_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor


@MM_WEIGHT_REGISTER("mxfp8")
class MMWeightWmxfp8Amxfp8dynamic(MMWeightQuantTemplate):
    """
    Name: W-mxfp8-A-mxfp8-dynamic

    Quant MM:
        Weight: mxfp8
        Act: mxfp8
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_mxfp8
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_mxfp8
        self.set_alpha()

    def set_alpha(self):
        self.alpha = torch.tensor(1.0, dtype=torch.float32)

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        self.alpha = self.alpha.to(self.weight.device)
        output_tensor = cutlass_scaled_mxfp8_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor
@MM_WEIGHT_REGISTER("nvfp4")
class MMWeightWnvfp4Anvfp4dynamic(MMWeightQuantTemplate):
    """
    Name: W-nvfp4-A-nvfp4-dynamic

    Quant MM:
        Weight: nvfp4
        Act: nvfp4
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.load_func = self.load_nvfp4
        self.weight_need_transpose = False
        self.act_quant_func = self.act_quant_nvfp4

    def apply(self, input_tensor):
        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
        output_tensor = cutlass_scaled_nvfp4_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, alpha=self.alpha, bias=self.bias)
        return output_tensor

    def to_cuda(self, non_blocking=False):
        self.weight = self.pin_weight.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_weight_scale"):
            self.weight_scale = self.pin_weight_scale.cuda(non_blocking=non_blocking)
            self.input_global_scale = self.pin_input_global_scale.cuda(non_blocking=non_blocking)
            self.alpha = self.pin_alpha.cuda(non_blocking=non_blocking)
        if hasattr(self, "pin_bias") and self.pin_bias is not None:
            self.bias = self.pin_bias.cuda(non_blocking=non_blocking)

    def to_cpu(self, non_blocking=False):
        if hasattr(self, "pin_weight"):
            self.weight = self.pin_weight.copy_(self.weight, non_blocking=non_blocking).cpu()
            if hasattr(self, "weight_scale_name"):
                self.weight_scale = self.pin_weight_scale.copy_(self.weight_scale, non_blocking=non_blocking).cpu()
                self.input_global_scale = self.pin_input_global_scale.copy_(self.input_global_scale, non_blocking=non_blocking).cpu()
                self.alpha = self.pin_alpha.copy_(self.alpha, non_blocking=non_blocking).cpu()
            if self.bias is not None:
                self.bias = self.pin_bias.copy_(self.bias, non_blocking=non_blocking).cpu()
        else:
            self.weight = self.weight.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "weight_scale"):
                self.weight_scale = self.weight_scale.to("cpu", non_blocking=non_blocking)
                self.input_global_scale = self.input_global_scale.to("cpu", non_blocking=non_blocking)
                self.alpha = self.alpha.to("cpu", non_blocking=non_blocking)
            if hasattr(self, "bias") and self.bias is not None:
                self.bias = self.bias.to("cpu", non_blocking=non_blocking)
@MM_WEIGHT_REGISTER("Calib")
class MMCalibNvfp4(MMWeight):
    """
    Name: calib

    Calib:
        absmax: torch.max(torch.abs(input_tensor))
    """

    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
        self.running_absmax = None
        self.count = 0
        self.decay = 0.9

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype, device = input_tensor.dtype, input_tensor.device
        current_absmax = torch.max(torch.abs(input_tensor)).to("cpu")
        if self.count % 2 == 0:
            if self.running_absmax is None:
                self.running_absmax = current_absmax
            else:
                self.running_absmax = self.decay * self.running_absmax + (1 - self.decay) * current_absmax
            CALIB["absmax"][self.weight_name] = self.running_absmax
        self.count = self.count + 1
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)
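MMCalibNvfp4 runs the un-quantized matmul while recording, every second call, an exponential moving average of the activation absmax: absmax_t = 0.9 * absmax_{t-1} + 0.1 * max|x|. A compact sketch of just the statistic, separated from the matmul for clarity:

import torch

def update_absmax(running, x, decay=0.9):
    # EMA of the per-layer activation absmax, as MMCalibNvfp4 maintains it.
    current = torch.max(torch.abs(x)).to("cpu")
    return current if running is None else decay * running + (1 - decay) * current

running = None
for x in [torch.randn(8, 64) for _ in range(4)]:  # stand-in activation batches
    running = update_absmax(running, x)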
@MM_WEIGHT_REGISTER("fp8-q8f")
class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
    """

...
lightx2v/models/networks/wan/model.py

@@ -60,7 +60,21 @@ class WanModel(CompiledMethodsMixin):
        self.clean_cuda_cache = self.config.get("clean_cuda_cache", False)
        self.dit_quantized = self.config.get("dit_quantized", False)
        if self.dit_quantized:
            assert self.config.get("dit_quant_scheme", "Default") in [
                "Default-Force-FP32",
                "fp8-vllm",
                "int8-vllm",
                "fp8-q8f",
                "int8-q8f",
                "fp8-b128-deepgemm",
                "fp8-sgl",
                "int8-sgl",
                "int8-torchao",
                "nvfp4",
                "mxfp4",
                "mxfp6-mxfp8",
                "mxfp8",
            ]
        self.device = device
        self._init_infer_class()
        self._init_weights()
...

@@ -169,6 +183,7 @@ class WanModel(CompiledMethodsMixin):
            safetensors_files = glob.glob(os.path.join(safetensors_path, "*.safetensors"))
        else:
            safetensors_files = [safetensors_path]
            safetensors_path = os.path.dirname(safetensors_path)
        weight_dict = {}
        for safetensor_path in safetensors_files:
...

@@ -192,6 +207,13 @@ class WanModel(CompiledMethodsMixin):
                else:
                    weight_dict[k] = f.get_tensor(k).to(self.device)
        if self.config.get("dit_quant_scheme", "Default") == "nvfp4":
            calib_path = os.path.join(safetensors_path, "calib.pt")
            logger.info(f"[CALIB] Loaded calibration data from: {calib_path}")
            calib_data = torch.load(calib_path, map_location="cpu")
            for k, v in calib_data["absmax"].items():
                weight_dict[k.replace(".weight", ".input_absmax")] = v.to(self.device)
        return weight_dict

    def _load_quant_split_ckpt(self, unified_dtype, sensitive_layer):  # Need rewrite

...
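Together with the runner change below, this closes the nvfp4 calibration loop: a run with do_mm_calib enabled writes calib.pt, and an nvfp4 inference run reads it back from the checkpoint directory, injecting one ".input_absmax" entry per ".weight". A sketch of the expected file contents; the weight key shown is illustrative:

import torch

calib = torch.load("calib.pt", map_location="cpu")    # written by DefaultRunner below
absmax = calib["absmax"]["blocks.0.attn.qkv.weight"]  # hypothetical key; 0-dim tensor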
lightx2v/models/networks/wan/weights/transformer_weights.py

@@ -21,6 +21,8 @@ class WanTransformerWeights(WeightModule):
        self.mm_type = config.get("dit_quant_scheme", "Default")
        if self.mm_type != "Default":
            assert config.get("dit_quantized") is True
        if config.get("do_mm_calib", False):
            self.mm_type = "Calib"
        self.blocks = WeightModuleList([WanTransformerAttentionBlock(i, self.task, self.mm_type, self.config) for i in range(self.blocks_num)])
        self.add_module("blocks", self.blocks)

...
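With this hook, a calibration pass is driven purely by config: do_mm_calib forces every block's mm_type to "Calib" regardless of the quant scheme. A plausible config fragment; the field names come from the config.get calls in this commit, while the surrounding structure is an assumption:

config = {
    "dit_quantized": True,    # required whenever dit_quant_scheme != "Default"
    "dit_quant_scheme": "nvfp4",
    "do_mm_calib": True,      # overrides mm_type to "Calib" for this run
}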
lightx2v/models/runners/default_runner.py

@@ -11,6 +11,7 @@ from requests.exceptions import RequestException
from lightx2v.server.metrics import monitor_cli
from lightx2v.utils.envs import *
from lightx2v.utils.generate_task_id import generate_task_id
from lightx2v.utils.global_paras import CALIB
from lightx2v.utils.memory_profiler import peak_memory_decorator
from lightx2v.utils.profiler import *
from lightx2v.utils.utils import save_to_video, vae_to_comfyui_image
...

@@ -176,6 +177,10 @@ class DefaultRunner(BaseRunner):
        self.model.transformer_weights.clear()
        self.model.pre_weight.clear()
        del self.model
        if self.config.get("do_mm_calib", False):
            calib_path = os.path.join(os.getcwd(), "calib.pt")
            torch.save(CALIB, calib_path)
            logger.info(f"[CALIB] Saved calibration data successfully to: {calib_path}")
        torch.cuda.empty_cache()
        gc.collect()
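What gets saved here is the CALIB global from lightx2v/utils/global_paras.py, so calib.pt is simply the nested dict the model loader above iterates over. A two-line sketch of the save side:

from lightx2v.utils.global_paras import CALIB
# After a calibration run: CALIB == {"absmax": {weight_name: running_absmax, ...}}
torch.save(CALIB, "calib.pt")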
...

@@ -258,6 +263,7 @@ class DefaultRunner(BaseRunner):
    def init_run(self):
        self.gen_video_final = None
        self.get_video_segment_num()
        if self.config.get("lazy_load", False) or self.config.get("unload_modules", False):
            self.model = self.load_transformer()

...
lightx2v/utils/global_paras.py (new file)

CALIB = {"absmax": {}}
lightx2v/utils/registry_factory.py

@@ -51,5 +51,6 @@ LN_WEIGHT_REGISTER = Register()
CONV3D_WEIGHT_REGISTER = Register()
CONV2D_WEIGHT_REGISTER = Register()
TENSOR_REGISTER = Register()
CONVERT_WEIGHT_REGISTER = Register()
RUNNER_REGISTER = Register()
tools/convert/converter.py

@@ -11,11 +11,21 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
import torch
from loguru import logger

try:
    from lora_loader import LoRALoader
except ImportError:
    pass
from safetensors import safe_open
from safetensors import torch as st
from tqdm import tqdm

try:
    from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
except ImportError:
    pass
from tools.convert.quant import *


def get_key_mapping_rules(direction, model_type):
    if model_type == "wan_dit":
...

@@ -349,7 +359,17 @@ def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
def quantize_model(
    weights,
    w_bit=8,
    target_keys=["attn", "ffn"],
    adapter_keys=None,
    key_idx=2,
    ignore_key=None,
    linear_dtype=torch.int8,
    non_linear_dtype=torch.float,
    comfyui_mode=False,
    comfyui_keys=[],
    linear_quant_type=None,
):
    """
    Quantize model weights in-place
...

@@ -414,7 +434,13 @@ def quantize_model(
            original_size += original_tensor_size

            # Quantize tensor and store results
            if linear_quant_type:
                quantizer = CONVERT_WEIGHT_REGISTER[linear_quant_type](tensor)
                w_q, scales, extra = quantizer.weight_quant_func(tensor)
                weight_global_scale = extra.get("weight_global_scale", None)  # For nvfp4
            else:
                w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
                weight_global_scale = None

            # Replace original tensor and store scales
            weights[key] = w_q
...

@@ -422,6 +448,8 @@ def quantize_model(
                weights[key.replace(".weight", ".scale_weight")] = scales
            else:
                weights[key + "_scale"] = scales
                if weight_global_scale:
                    weights[key + "_global_scale"] = weight_global_scale

            quantized_tensor_size = w_q.numel() * w_q.element_size()
            scale_size = scales.numel() * scales.element_size()
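The key convention now carries up to three artifacts per quantized linear, and it is exactly what load_nvfp4 in mm_weight.py expects at inference time. A comment-only sketch of the assumed names for one layer after conversion with --linear_quant_type NVFP4:

# "blocks.0.ffn.0.weight"               -> packed quantized weight
# "blocks.0.ffn.0.weight_scale"         -> per-block / per-channel scales
# "blocks.0.ffn.0.weight_global_scale"  -> NVFP4 only; read back via f"{weight_name}_global_scale"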
...

@@ -622,6 +650,7 @@ def convert_weights(args):
        non_linear_dtype=args.non_linear_dtype,
        comfyui_mode=args.comfyui_mode,
        comfyui_keys=args.comfyui_keys,
        linear_quant_type=args.linear_quant_type,
    )
    os.makedirs(args.output, exist_ok=True)
...

@@ -793,6 +822,12 @@ def main():
        choices=["torch.int8", "torch.float8_e4m3fn"],
        help="Data type for linear",
    )
    parser.add_argument(
        "--linear_quant_type",
        type=str,
        choices=["INT8", "FP8", "NVFP4", "MXFP4", "MXFP6", "MXFP8"],
        help="Quantization type for linear weights",
    )
    parser.add_argument(
        "--non_linear_dtype",
        type=str,

...
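A hypothetical invocation exercising the new flag; only --linear_quant_type (and its choices) plus args.output are visible in this diff, so every other argument of the converter is assumed:

import subprocess

# Sketch only: the converter likely needs source/model-type arguments not shown here.
subprocess.run(
    ["python", "tools/convert/converter.py", "--linear_quant_type", "NVFP4", "--output", "out_dir"],
    check=True,
)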
tools/convert/quant/__init__.py (new file)

from .quant import *
tools/convert/quant/quant.py (new file)
from abc import ABCMeta

import torch
from qtorch.quant import float_quantize

try:
    from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
    from lightx2v_kernel.gemm import scaled_mxfp4_quant, scaled_mxfp6_quant, scaled_mxfp8_quant, scaled_nvfp4_quant
except ImportError:
    pass


class QuantTemplate(metaclass=ABCMeta):
    def __init__(self, weight):
        if weight.dim() != 2:
            raise ValueError(f"Only 2D tensors supported. Got {weight.dim()}D tensor")
        if torch.isnan(weight).any():
            raise ValueError("Tensor contains NaN values")
        self.weight_quant_func = None
        self.extra = {}
@CONVERT_WEIGHT_REGISTER("INT8")
class QuantWeightINT8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_int8_weight

    @torch.no_grad()
    def load_int8_weight(self, w):
        org_w_shape = w.shape
        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        qmin, qmax = -128, 127
        scales = max_val / qmax
        w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0
        scales = scales.view(org_w_shape[0], -1)
        w_q = w_q.reshape(org_w_shape)
        return w_q, scales, self.extra
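QuantWeightINT8 performs symmetric per-output-channel quantization: each row's scale is absmax / 127 and values round into [-128, 127]. A worked roundtrip showing the quantization error on a single row:

import torch

w = torch.tensor([[0.50, -1.00, 0.25]])
scales = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / 127       # 1.0 / 127 ~ 0.00787
w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)   # [[64, -127, 32]]
w_dq = w_q.float() * scales   # [[0.5039, -1.0000, 0.2520]]: small roundtrip error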
@CONVERT_WEIGHT_REGISTER("FP8")
class QuantWeightFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp8_weight

    @torch.no_grad()
    def load_fp8_weight(self, w):
        org_w_shape = w.shape
        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        finfo = torch.finfo(torch.float8_e4m3fn)
        qmin, qmax = finfo.min, finfo.max
        scales = max_val / qmax
        scaled_tensor = w / scales
        scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
        w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(torch.float8_e4m3fn)
        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w_q).sum() == 0
        scales = scales.view(org_w_shape[0], -1)
        w_q = w_q.reshape(org_w_shape)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("MXFP4")
class QuantWeightMxFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp4_weight

    @torch.no_grad()
    def load_mxfp4_weight(self, w):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp4_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("MXFP6")
class QuantWeightMxFP6(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp6_weight

    @torch.no_grad()
    def load_mxfp6_weight(self, w):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp6_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra


@CONVERT_WEIGHT_REGISTER("MXFP8")
class QuantWeightMxFP8(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_mxfp8_weight

    @torch.no_grad()
    def load_mxfp8_weight(self, w):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        w_q, scales = scaled_mxfp8_quant(w)
        w_q, scales = w_q.to(device), scales.to(device)
        return w_q, scales, self.extra
@CONVERT_WEIGHT_REGISTER("NVFP4")
class QuantWeightNVFP4(QuantTemplate):
    def __init__(self, weight):
        super().__init__(weight)
        self.weight_quant_func = self.load_fp4_weight

    @torch.no_grad()
    def load_fp4_weight(self, w):
        device = w.device
        w = w.cuda().to(torch.bfloat16)
        weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
        w_q, scales = scaled_nvfp4_quant(w, weight_global_scale)
        w_q, scales = w_q.to(device), scales.to(device)
        self.extra["weight_global_scale"] = weight_global_scale.to(device)
        return w_q, scales, self.extra
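These classes exist solely to be dispatched by quantize_model via the registry; extra is empty for every type except NVFP4, whose weight_global_scale mirrors the 2688 / absmax scale used at load time. A recap of the dispatch, using names from converter.py above:

# How quantize_model dispatches a weight tensor to one of these quantizers:
quantizer = CONVERT_WEIGHT_REGISTER["MXFP8"](tensor)       # __init__ validates 2-D, no NaNs
w_q, scales, extra = quantizer.weight_quant_func(tensor)   # extra is {} except for "NVFP4"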