xuwx1 / LightX2V · Commits · bbd164c6

Commit bbd164c6 (unverified), authored Nov 21, 2025 by Bilang ZHANG, committed by GitHub on Nov 21, 2025

update convert (#481)

--linear_dtype and --linear_quant_dtype unify as --linear_type
Parent: 4beb6ebc

Showing 3 changed files with 64 additions and 110 deletions:

tools/convert/converter.py    +15 -74
tools/convert/quant/quant.py  +26 -18
tools/convert/readme.md       +23 -18
tools/convert/converter.py (view file @ bbd164c6)
@@ -27,6 +27,11 @@ sys.path.append(str(Path(__file__).parent.parent.parent))
 from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
 from tools.convert.quant import *
 
+dtype_mapping = {
+    "int8": torch.int8,
+    "fp8": torch.float8_e4m3fn,
+}
+
 
 def get_key_mapping_rules(direction, model_type):
     if model_type == "wan_dit":
@@ -306,59 +311,6 @@ def get_key_mapping_rules(direction, model_type):
         raise ValueError(f"Unsupported model type: {model_type}")
 
 
-def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
-    """
-    Quantize a 2D tensor to specified bit width using symmetric min-max quantization
-
-    Args:
-        w: Input tensor to quantize (must be 2D)
-        w_bit: Quantization bit width (default: 8)
-
-    Returns:
-        quantized: Quantized tensor (int8)
-        scales: Scaling factors per row
-    """
-    if w.dim() != 2:
-        raise ValueError(f"Only 2D tensors supported. Got {w.dim()}D tensor")
-    if torch.isnan(w).any():
-        raise ValueError("Tensor contains NaN values")
-    if w_bit != 8:
-        raise ValueError("Only support 8 bits")
-
-    org_w_shape = w.shape
-
-    # Calculate quantization parameters
-    if not comfyui_mode:
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
-    else:
-        max_val = w.abs().max()
-    if dtype == torch.float8_e4m3fn:
-        finfo = torch.finfo(dtype)
-        qmin, qmax = finfo.min, finfo.max
-    elif dtype == torch.int8:
-        qmin, qmax = -128, 127
-
-    # Quantize tensor
-    scales = max_val / qmax
-    if dtype == torch.float8_e4m3fn:
-        from qtorch.quant import float_quantize
-
-        scaled_tensor = w / scales
-        scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
-        w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(dtype)
-    else:
-        w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(dtype)
-
-    assert torch.isnan(scales).sum() == 0
-    assert torch.isnan(w_q).sum() == 0
-
-    if not comfyui_mode:
-        scales = scales.view(org_w_shape[0], -1)
-    w_q = w_q.reshape(org_w_shape)
-
-    return w_q, scales
-
-
 def quantize_model(
     weights,
     w_bit=8,
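For context, the `quantize_tensor` helper removed above implemented plain symmetric min-max quantization: each row is scaled by `max|row| / qmax`, rounded to the quantized dtype, and the original weight can be approximately recovered as `w_q * scales`. A minimal round-trip sketch of that scheme (illustrative only, not part of this diff):

```python
import torch

w = torch.randn(64, 128)

max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
scales = max_val / 127                                    # per-row scale, shape (64, 1)
w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)

w_rec = w_q.float() * scales                              # dequantize
print((w - w_rec).abs().max())                            # rounding error, at most ~scale/2 per element
```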
@@ -366,11 +318,10 @@ def quantize_model(
     adapter_keys=None,
     key_idx=2,
     ignore_key=None,
-    linear_dtype=torch.int8,
+    linear_type="int8",
     non_linear_dtype=torch.float,
     comfyui_mode=False,
     comfyui_keys=[],
-    linear_quant_type=None,
 ):
     """
     Quantize model weights in-place
@@ -435,13 +386,9 @@ def quantize_model(
             original_size += original_tensor_size
 
             # Quantize tensor and store results
-            if linear_quant_type:
-                quantizer = CONVERT_WEIGHT_REGISTER[linear_quant_type](tensor)
-                w_q, scales, extra = quantizer.weight_quant_func(tensor)
-                weight_global_scale = extra.get("weight_global_scale", None)  # For nvfp4
-            else:
-                w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
-                weight_global_scale = None
+            quantizer = CONVERT_WEIGHT_REGISTER[linear_type](tensor)
+            w_q, scales, extra = quantizer.weight_quant_func(tensor, comfyui_mode)
+            weight_global_scale = extra.get("weight_global_scale", None)  # For nvfp4
 
             # Replace original tensor and store scales
             weights[key] = w_q
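The per-dtype branch is gone: `quantize_model` now resolves a quantizer class from `CONVERT_WEIGHT_REGISTER` by the lower-case `linear_type` string and calls its `weight_quant_func` with `comfyui_mode`. A self-contained sketch of this registry-dispatch pattern (the `Register` class and the trimmed `QuantWeightINT8` below are illustrative stand-ins; the real registry lives in `lightx2v.utils.registry_factory`):

```python
import torch


class Register:
    """Minimal decorator-based registry mirroring the usage seen in the diff."""

    def __init__(self):
        self._map = {}

    def __call__(self, name):
        def decorator(cls):
            self._map[name] = cls
            return cls
        return decorator

    def __getitem__(self, name):
        return self._map[name]


CONVERT_WEIGHT_REGISTER = Register()


@CONVERT_WEIGHT_REGISTER("int8")
class QuantWeightINT8:
    def __init__(self, weight):
        self.extra = {}
        self.weight_quant_func = self.load_int8_weight

    def load_int8_weight(self, w, comfyui_mode=False):
        # Symmetric per-row int8 quantization, as in the project's INT8 quantizer.
        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        scales = max_val / 127
        w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
        return w_q, scales, self.extra


# Dispatch exactly as quantize_model now does:
tensor = torch.randn(4, 8)
quantizer = CONVERT_WEIGHT_REGISTER["int8"](tensor)
w_q, scales, extra = quantizer.weight_quant_func(tensor, comfyui_mode=False)
weight_global_scale = extra.get("weight_global_scale", None)  # only set for nvfp4
```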
@@ -637,6 +584,7 @@ def convert_weights(args):
     if args.quantized:
         if args.full_quantized and args.comfyui_mode:
             logger.info("Quant all tensors...")
+            assert args.linear_dtype, f"Error: only support 'torch.int8' and 'torch.float8_e4m3fn'."
             for k in converted_weights.keys():
                 converted_weights[k] = converted_weights[k].float().to(args.linear_dtype)
         else:
@@ -647,11 +595,10 @@ def convert_weights(args):
             adapter_keys=args.adapter_keys,
             key_idx=args.key_idx,
             ignore_key=args.ignore_key,
-            linear_dtype=args.linear_dtype,
+            linear_type=args.linear_type,
             non_linear_dtype=args.non_linear_dtype,
             comfyui_mode=args.comfyui_mode,
             comfyui_keys=args.comfyui_keys,
-            linear_quant_type=args.linear_quant_type,
         )
 
     os.makedirs(args.output, exist_ok=True)
@@ -818,16 +765,10 @@ def main():
         help="Device to use for quantization (cpu/cuda)",
     )
     parser.add_argument(
-        "--linear_dtype",
-        type=str,
-        choices=["torch.int8", "torch.float8_e4m3fn"],
-        help="Data type for linear",
-    )
-    parser.add_argument(
-        "--linear_quant_type",
+        "--linear_type",
         type=str,
-        choices=["INT8", "FP8", "NVFP4", "MXFP4", "MXFP6", "MXFP8"],
-        help="Data type for linear",
+        choices=["int8", "fp8", "nvfp4", "mxfp4", "mxfp6", "mxfp8"],
+        help="Quant type for linear",
     )
     parser.add_argument(
         "--non_linear_dtype",
@@ -870,7 +811,7 @@ def main():
         logger.warning("--chunk_size is ignored when using --single_file option.")
 
     if args.quantized:
-        args.linear_dtype = eval(args.linear_dtype)
+        args.linear_dtype = dtype_mapping.get(args.linear_type, None)
         args.non_linear_dtype = eval(args.non_linear_dtype)
 
     model_type_keys_map = {
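With the unified flag, `args.linear_dtype` is now derived from `args.linear_type` via `dtype_mapping`, so it only holds a real torch dtype for `int8`/`fp8` and falls back to `None` for the block-scaled formats, which is presumably why the ComfyUI full-quantized path asserts on it above. A quick sketch of that resolution (assumed behavior, inferred from the diff):

```python
import torch

dtype_mapping = {"int8": torch.int8, "fp8": torch.float8_e4m3fn}

for linear_type in ["int8", "fp8", "nvfp4", "mxfp8"]:
    linear_dtype = dtype_mapping.get(linear_type, None)
    print(linear_type, "->", linear_dtype)
# int8  -> torch.int8
# fp8   -> torch.float8_e4m3fn
# nvfp4 -> None
# mxfp8 -> None
```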
tools/convert/quant/quant.py (view file @ bbd164c6)
@@ -22,16 +22,19 @@ class QuantTemplate(metaclass=ABCMeta):
         self.extra = {}
 
 
-@CONVERT_WEIGHT_REGISTER("INT8")
+@CONVERT_WEIGHT_REGISTER("int8")
 class QuantWeightINT8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_int8_weight
 
     @torch.no_grad()
-    def load_int8_weight(self, w):
+    def load_int8_weight(self, w, comfyui_mode=False):
         org_w_shape = w.shape
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        if not comfyui_mode:
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        else:
+            max_val = w.abs().max()
         qmin, qmax = -128, 127
         scales = max_val / qmax
         w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
@@ -39,22 +42,26 @@ class QuantWeightINT8(QuantTemplate):
         assert torch.isnan(scales).sum() == 0
         assert torch.isnan(w_q).sum() == 0
 
-        scales = scales.view(org_w_shape[0], -1)
-        w_q = w_q.reshape(org_w_shape)
+        if not comfyui_mode:
+            scales = scales.view(org_w_shape[0], -1)
+            w_q = w_q.reshape(org_w_shape)
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("FP8")
+@CONVERT_WEIGHT_REGISTER("fp8")
 class QuantWeightFP8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_fp8_weight
 
     @torch.no_grad()
-    def load_fp8_weight(self, w):
+    def load_fp8_weight(self, w, comfyui_mode=False):
         org_w_shape = w.shape
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        if not comfyui_mode:
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        else:
+            max_val = w.abs().max()
         finfo = torch.finfo(torch.float8_e4m3fn)
         qmin, qmax = finfo.min, finfo.max
         scales = max_val / qmax
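The new `comfyui_mode` flag switches both the INT8 and FP8 weight quantizers from per-output-row scales to a single per-tensor scale and skips the final reshape. A small illustration of the difference in scale shapes (illustrative only, not part of this diff):

```python
import torch

w = torch.randn(128, 256)

# Default path: per-row symmetric int8 scales, as in load_int8_weight without comfyui_mode
max_per_row = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
scales_row = max_per_row / 127              # shape: (128, 1)

# comfyui_mode path: one scale for the whole tensor
scales_tensor = w.abs().max() / 127         # 0-dim scalar

w_q_row = torch.clamp(torch.round(w / scales_row), -128, 127).to(torch.int8)
w_q_tensor = torch.clamp(torch.round(w / scales_tensor), -128, 127).to(torch.int8)

print(scales_row.shape, scales_tensor.shape)  # torch.Size([128, 1]) torch.Size([])
```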
@@ -65,20 +72,21 @@ class QuantWeightFP8(QuantTemplate):
         assert torch.isnan(scales).sum() == 0
         assert torch.isnan(w_q).sum() == 0
 
-        scales = scales.view(org_w_shape[0], -1)
-        w_q = w_q.reshape(org_w_shape)
+        if not comfyui_mode:
+            scales = scales.view(org_w_shape[0], -1)
+            w_q = w_q.reshape(org_w_shape)
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP4")
+@CONVERT_WEIGHT_REGISTER("mxfp4")
 class QuantWeightMxFP4(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp4_weight
 
     @torch.no_grad()
-    def load_mxfp4_weight(self, w):
+    def load_mxfp4_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp4_quant(w)
@@ -86,14 +94,14 @@ class QuantWeightMxFP4(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP6")
+@CONVERT_WEIGHT_REGISTER("mxfp6")
 class QuantWeightMxFP6(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp6_weight
 
     @torch.no_grad()
-    def load_mxfp6_weight(self, w):
+    def load_mxfp6_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp6_quant(w)
@@ -101,14 +109,14 @@ class QuantWeightMxFP6(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP8")
+@CONVERT_WEIGHT_REGISTER("mxfp8")
 class QuantWeightMxFP8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp8_weight
 
     @torch.no_grad()
-    def load_mxfp8_weight(self, w):
+    def load_mxfp8_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp8_quant(w)
@@ -116,14 +124,14 @@ class QuantWeightMxFP8(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("NVFP4")
+@CONVERT_WEIGHT_REGISTER("nvfp4")
 class QuantWeightNVFP4(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_fp4_weight
 
     @torch.no_grad()
-    def load_fp4_weight(self, w):
+    def load_fp4_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
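The `2688.0` constant is presumably the product of the FP8 E4M3 maximum (448) and the FP4 E2M1 maximum (6), the usual choice for an NVFP4 per-tensor global scale; a quick check of that arithmetic:

```python
import torch

FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max   # 448.0
FLOAT4_E2M1_MAX = 6.0                                     # largest representable e2m1 value
print(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX)                  # 2688.0
```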
tools/convert/readme.md (view file @ bbd164c6)
@@ -5,7 +5,7 @@ A powerful model weight conversion tool that supports format conversion, quantiz
 ## Main Features
 
 - **Format Conversion**: Support PyTorch (.pth) and SafeTensors (.safetensors) format conversion
-- **Model Quantization**: Support INT8 and FP8 quantization to significantly reduce model size
+- **Model Quantization**: Support INT8, FP8, NVFP4, MXFP4, MXFP6 and MXFP8 quantization to significantly reduce model size
 - **Architecture Conversion**: Support conversion between LightX2V and Diffusers architectures
 - **LoRA Merging**: Support loading and merging multiple LoRA formats
 - **Multi-Model Support**: Support Wan DiT, Qwen Image DiT, T5, CLIP, etc.
@@ -42,16 +42,21 @@ A powerful model weight conversion tool that supports format conversion, quantiz
 - `--quantized`: Enable quantization
 - `--bits`: Quantization bit width, currently only supports 8-bit
-- `--linear_dtype`: Linear layer quantization type
-  - `torch.int8`: INT8 quantization
-  - `torch.float8_e4m3fn`: FP8 quantization
+- `--linear_type`: Linear layer quantization type
+  - `int8`: INT8 quantization (torch.int8)
+  - `fp8`: FP8 quantization (torch.float8_e4m3fn)
+  - `nvfp4`: NVFP4 quantization
+  - `mxfp4`: MXFP4 quantization
+  - `mxfp6`: MXFP6 quantization
+  - `mxfp8`: MXFP8 quantization
 - `--non_linear_dtype`: Non-linear layer data type
   - `torch.bfloat16`: BF16
   - `torch.float16`: FP16
   - `torch.float32`: FP32 (default)
 - `--device`: Device for quantization, `cpu` or `cuda` (default)
-- `--comfyui_mode`: ComfyUI compatible mode
+- `--comfyui_mode`: ComfyUI compatible mode (only int8 and fp8)
 - `--full_quantized`: Full quantization mode (effective in ComfyUI mode)
+
+For nvfp4, mxfp4, mxfp6 and mxfp8, please install the corresponding kernels following LightX2V/lightx2v_kernel/README.md.
 
 ### LoRA Parameters
@@ -105,7 +110,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan_int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --model_type wan_dit \
     --quantized \
     --save_by_block
@@ -118,7 +123,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_int8_lightx2v \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --model_type wan_dit \
     --quantized \
     --single_file
@@ -133,7 +138,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan_fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -147,7 +152,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -161,7 +166,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -176,7 +181,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -196,7 +201,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_t5 \
     --quantized
@@ -209,7 +214,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_t5 \
     --quantized
@@ -224,7 +229,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --non_linear_dtype torch.float16 \
     --model_type wan_clip \
     --quantized
@@ -237,7 +242,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.float16 \
     --model_type wan_clip \
     --quantized
@@ -318,7 +323,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file
 ```
@@ -333,7 +338,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file \
     --comfyui_mode
 ```
@@ -349,7 +354,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file \
     --comfyui_mode \
     --full_quantized