Unverified commit bbd164c6, authored by Bilang ZHANG, committed by GitHub

update convert (#481)

Unify --linear_dtype and --linear_quant_dtype into a single --linear_type option.
parent 4beb6ebc
@@ -27,6 +27,11 @@ sys.path.append(str(Path(__file__).parent.parent.parent))
 from lightx2v.utils.registry_factory import CONVERT_WEIGHT_REGISTER
 from tools.convert.quant import *
 
+dtype_mapping = {
+    "int8": torch.int8,
+    "fp8": torch.float8_e4m3fn,
+}
+
 def get_key_mapping_rules(direction, model_type):
     if model_type == "wan_dit":
@@ -306,59 +311,6 @@ def get_key_mapping_rules(direction, model_type):
         raise ValueError(f"Unsupported model type: {model_type}")
 
-def quantize_tensor(w, w_bit=8, dtype=torch.int8, comfyui_mode=False):
-    """
-    Quantize a 2D tensor to specified bit width using symmetric min-max quantization
-
-    Args:
-        w: Input tensor to quantize (must be 2D)
-        w_bit: Quantization bit width (default: 8)
-
-    Returns:
-        quantized: Quantized tensor (int8)
-        scales: Scaling factors per row
-    """
-    if w.dim() != 2:
-        raise ValueError(f"Only 2D tensors supported. Got {w.dim()}D tensor")
-    if torch.isnan(w).any():
-        raise ValueError("Tensor contains NaN values")
-    if w_bit != 8:
-        raise ValueError("Only support 8 bits")
-
-    org_w_shape = w.shape
-
-    # Calculate quantization parameters
-    if not comfyui_mode:
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
-    else:
-        max_val = w.abs().max()
-    if dtype == torch.float8_e4m3fn:
-        finfo = torch.finfo(dtype)
-        qmin, qmax = finfo.min, finfo.max
-    elif dtype == torch.int8:
-        qmin, qmax = -128, 127
-
-    # Quantize tensor
-    scales = max_val / qmax
-    if dtype == torch.float8_e4m3fn:
-        from qtorch.quant import float_quantize
-
-        scaled_tensor = w / scales
-        scaled_tensor = torch.clip(scaled_tensor, qmin, qmax)
-        w_q = float_quantize(scaled_tensor.float(), 4, 3, rounding="nearest").to(dtype)
-    else:
-        w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(dtype)
-
-    assert torch.isnan(scales).sum() == 0
-    assert torch.isnan(w_q).sum() == 0
-
-    if not comfyui_mode:
-        scales = scales.view(org_w_shape[0], -1)
-    w_q = w_q.reshape(org_w_shape)
-
-    return w_q, scales
-
 def quantize_model(
     weights,
     w_bit=8,
@@ -366,11 +318,10 @@ def quantize_model(
     adapter_keys=None,
     key_idx=2,
     ignore_key=None,
-    linear_dtype=torch.int8,
+    linear_type="int8",
     non_linear_dtype=torch.float,
     comfyui_mode=False,
     comfyui_keys=[],
-    linear_quant_type=None,
 ):
     """
     Quantize model weights in-place
@@ -435,13 +386,9 @@ def quantize_model(
             original_size += original_tensor_size
 
             # Quantize tensor and store results
-            if linear_quant_type:
-                quantizer = CONVERT_WEIGHT_REGISTER[linear_quant_type](tensor)
-                w_q, scales, extra = quantizer.weight_quant_func(tensor)
-                weight_global_scale = extra.get("weight_global_scale", None)  # For nvfp4
-            else:
-                w_q, scales = quantize_tensor(tensor, w_bit, linear_dtype, comfyui_mode)
-                weight_global_scale = None
+            quantizer = CONVERT_WEIGHT_REGISTER[linear_type](tensor)
+            w_q, scales, extra = quantizer.weight_quant_func(tensor, comfyui_mode)
+            weight_global_scale = extra.get("weight_global_scale", None)  # For nvfp4
 
             # Replace original tensor and store scales
             weights[key] = w_q
@@ -637,6 +584,7 @@ def convert_weights(args):
     if args.quantized:
         if args.full_quantized and args.comfyui_mode:
             logger.info("Quant all tensors...")
+            assert args.linear_dtype, f"Error: only support 'torch.int8' and 'torch.float8_e4m3fn'."
             for k in converted_weights.keys():
                 converted_weights[k] = converted_weights[k].float().to(args.linear_dtype)
         else:
@@ -647,11 +595,10 @@ def convert_weights(args):
                 adapter_keys=args.adapter_keys,
                 key_idx=args.key_idx,
                 ignore_key=args.ignore_key,
-                linear_dtype=args.linear_dtype,
+                linear_type=args.linear_type,
                 non_linear_dtype=args.non_linear_dtype,
                 comfyui_mode=args.comfyui_mode,
                 comfyui_keys=args.comfyui_keys,
-                linear_quant_type=args.linear_quant_type,
             )
 
     os.makedirs(args.output, exist_ok=True)
@@ -818,16 +765,10 @@ def main():
         help="Device to use for quantization (cpu/cuda)",
     )
     parser.add_argument(
-        "--linear_dtype",
-        type=str,
-        choices=["torch.int8", "torch.float8_e4m3fn"],
-        help="Data type for linear",
-    )
-    parser.add_argument(
-        "--linear_quant_type",
+        "--linear_type",
         type=str,
-        choices=["INT8", "FP8", "NVFP4", "MXFP4", "MXFP6", "MXFP8"],
-        help="Data type for linear",
+        choices=["int8", "fp8", "nvfp4", "mxfp4", "mxfp6", "mxfp8"],
+        help="Quant type for linear",
     )
     parser.add_argument(
         "--non_linear_dtype",
@@ -870,7 +811,7 @@ def main():
         logger.warning("--chunk_size is ignored when using --single_file option.")
 
     if args.quantized:
-        args.linear_dtype = eval(args.linear_dtype)
+        args.linear_dtype = dtype_mapping.get(args.linear_type, None)
         args.non_linear_dtype = eval(args.non_linear_dtype)
 
     model_type_keys_map = {
...
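For reference, the per-row symmetric min-max scheme that the converter applies to int8 linear weights (formerly in quantize_tensor, now in QuantWeightINT8) boils down to the following minimal sketch. It assumes only torch; the helper names are illustrative and not part of the tool.

```python
import torch


def int8_quantize_per_row(w: torch.Tensor):
    """Per-row symmetric min-max quantization of a 2D weight to int8."""
    assert w.dim() == 2, "this sketch only handles 2D weights"
    max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)  # per-row |w| max, shape [rows, 1]
    scales = max_val / 127                                       # one scale per output row
    w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
    return w_q, scales


def int8_dequantize(w_q: torch.Tensor, scales: torch.Tensor):
    """Recover an approximate float weight from int8 values and per-row scales."""
    return w_q.float() * scales


# Round-trip check: per-element error is bounded by roughly half a quantization step.
w = torch.randn(1024, 1024)
w_q, scales = int8_quantize_per_row(w)
err = (w - int8_dequantize(w_q, scales)).abs().max()
print(f"int8 payload: {w_q.element_size()} byte/element, max abs error: {err.item():.5f}")
```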
@@ -22,16 +22,19 @@ class QuantTemplate(metaclass=ABCMeta):
         self.extra = {}
 
 
-@CONVERT_WEIGHT_REGISTER("INT8")
+@CONVERT_WEIGHT_REGISTER("int8")
 class QuantWeightINT8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_int8_weight
 
     @torch.no_grad()
-    def load_int8_weight(self, w):
+    def load_int8_weight(self, w, comfyui_mode=False):
         org_w_shape = w.shape
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        if not comfyui_mode:
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        else:
+            max_val = w.abs().max()
         qmin, qmax = -128, 127
         scales = max_val / qmax
         w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
@@ -39,22 +42,26 @@ class QuantWeightINT8(QuantTemplate):
         assert torch.isnan(scales).sum() == 0
         assert torch.isnan(w_q).sum() == 0
 
-        scales = scales.view(org_w_shape[0], -1)
-        w_q = w_q.reshape(org_w_shape)
+        if not comfyui_mode:
+            scales = scales.view(org_w_shape[0], -1)
+        w_q = w_q.reshape(org_w_shape)
 
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("FP8")
+@CONVERT_WEIGHT_REGISTER("fp8")
 class QuantWeightFP8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_fp8_weight
 
     @torch.no_grad()
-    def load_fp8_weight(self, w):
+    def load_fp8_weight(self, w, comfyui_mode=False):
         org_w_shape = w.shape
-        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        if not comfyui_mode:
+            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
+        else:
+            max_val = w.abs().max()
         finfo = torch.finfo(torch.float8_e4m3fn)
         qmin, qmax = finfo.min, finfo.max
         scales = max_val / qmax
@@ -65,20 +72,21 @@ class QuantWeightFP8(QuantTemplate):
         assert torch.isnan(scales).sum() == 0
         assert torch.isnan(w_q).sum() == 0
 
-        scales = scales.view(org_w_shape[0], -1)
-        w_q = w_q.reshape(org_w_shape)
+        if not comfyui_mode:
+            scales = scales.view(org_w_shape[0], -1)
+        w_q = w_q.reshape(org_w_shape)
 
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP4")
+@CONVERT_WEIGHT_REGISTER("mxfp4")
 class QuantWeightMxFP4(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp4_weight
 
     @torch.no_grad()
-    def load_mxfp4_weight(self, w):
+    def load_mxfp4_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp4_quant(w)
@@ -86,14 +94,14 @@ class QuantWeightMxFP4(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP6")
+@CONVERT_WEIGHT_REGISTER("mxfp6")
 class QuantWeightMxFP6(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp6_weight
 
     @torch.no_grad()
-    def load_mxfp6_weight(self, w):
+    def load_mxfp6_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp6_quant(w)
@@ -101,14 +109,14 @@ class QuantWeightMxFP6(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("MXFP8")
+@CONVERT_WEIGHT_REGISTER("mxfp8")
 class QuantWeightMxFP8(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_mxfp8_weight
 
     @torch.no_grad()
-    def load_mxfp8_weight(self, w):
+    def load_mxfp8_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         w_q, scales = scaled_mxfp8_quant(w)
@@ -116,14 +124,14 @@ class QuantWeightMxFP8(QuantTemplate):
         return w_q, scales, self.extra
 
 
-@CONVERT_WEIGHT_REGISTER("NVFP4")
+@CONVERT_WEIGHT_REGISTER("nvfp4")
 class QuantWeightNVFP4(QuantTemplate):
     def __init__(self, weight):
         super().__init__(weight)
         self.weight_quant_func = self.load_fp4_weight
 
     @torch.no_grad()
-    def load_fp4_weight(self, w):
+    def load_fp4_weight(self, w, comfyui_mode=False):
         device = w.device
         w = w.cuda().to(torch.bfloat16)
         weight_global_scale = (2688.0 / torch.max(torch.abs(w))).to(torch.float32)
...
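The quantize_model change above now selects a quantizer purely by the lowercase --linear_type key, via CONVERT_WEIGHT_REGISTER[linear_type](tensor). The sketch below illustrates that registration-and-dispatch pattern with a plain dict standing in for the registry; the real CONVERT_WEIGHT_REGISTER comes from lightx2v.utils.registry_factory and its internals may differ, and the Demo* names are hypothetical.

```python
import torch

# Plain-dict stand-in for the project's CONVERT_WEIGHT_REGISTER
# (the real one lives in lightx2v.utils.registry_factory).
DEMO_REGISTRY = {}


def register(name):
    """Decorator that stores a quantizer class under a lowercase key."""
    def deco(cls):
        DEMO_REGISTRY[name] = cls
        return cls
    return deco


@register("int8")
class DemoInt8Quantizer:
    def __init__(self, weight):
        self.extra = {}  # room for extras such as weight_global_scale (nvfp4)
        self.weight_quant_func = self.load_int8_weight

    @torch.no_grad()
    def load_int8_weight(self, w, comfyui_mode=False):
        max_val = w.abs().max() if comfyui_mode else w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
        scales = max_val / 127
        w_q = torch.clamp(torch.round(w / scales), -128, 127).to(torch.int8)
        return w_q, scales, self.extra


# Dispatch mirrors the new quantize_model code path: look up by the --linear_type value.
linear_type = "int8"
tensor = torch.randn(8, 16)
quantizer = DEMO_REGISTRY[linear_type](tensor)
w_q, scales, extra = quantizer.weight_quant_func(tensor, False)
print(w_q.dtype, scales.shape, extra)
```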
@@ -5,7 +5,7 @@ A powerful model weight conversion tool that supports format conversion, quantiz
 ## Main Features
 
 - **Format Conversion**: Support PyTorch (.pth) and SafeTensors (.safetensors) format conversion
-- **Model Quantization**: Support INT8 and FP8 quantization to significantly reduce model size
+- **Model Quantization**: Support INT8, FP8, NVFP4, MXFP4, MXFP6 and MXFP8 quantization to significantly reduce model size
 - **Architecture Conversion**: Support conversion between LightX2V and Diffusers architectures
 - **LoRA Merging**: Support loading and merging multiple LoRA formats
 - **Multi-Model Support**: Support Wan DiT, Qwen Image DiT, T5, CLIP, etc.
@@ -42,16 +42,21 @@ A powerful model weight conversion tool that supports format conversion, quantiz
 - `--quantized`: Enable quantization
 - `--bits`: Quantization bit width, currently only supports 8-bit
-- `--linear_dtype`: Linear layer quantization type
-  - `torch.int8`: INT8 quantization
-  - `torch.float8_e4m3fn`: FP8 quantization
+- `--linear_type`: Linear layer quantization type
+  - `int8`: INT8 quantization (torch.int8)
+  - `fp8`: FP8 quantization (torch.float8_e4m3fn)
+  - `nvfp4`: NVFP4 quantization
+  - `mxfp4`: MXFP4 quantization
+  - `mxfp6`: MXFP6 quantization
+  - `mxfp8`: MXFP8 quantization
 - `--non_linear_dtype`: Non-linear layer data type
   - `torch.bfloat16`: BF16
   - `torch.float16`: FP16
   - `torch.float32`: FP32 (default)
 - `--device`: Device for quantization, `cpu` or `cuda` (default)
-- `--comfyui_mode`: ComfyUI compatible mode
+- `--comfyui_mode`: ComfyUI compatible mode (only int8 and fp8)
 - `--full_quantized`: Full quantization mode (effective in ComfyUI mode)
+
+For nvfp4, mxfp4, mxfp6 and mxfp8, please install the required kernels following LightX2V/lightx2v_kernel/README.md.
 
 ### LoRA Parameters
@@ -105,7 +110,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan_int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --model_type wan_dit \
     --quantized \
     --save_by_block
@@ -118,7 +123,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_int8_lightx2v \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --model_type wan_dit \
     --quantized \
     --single_file
@@ -133,7 +138,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan_fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -147,7 +152,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -161,7 +166,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -176,7 +181,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .safetensors \
     --output_name wan2.1_i2v_480p_scaled_fp8_e4m3_lightx2v_comfyui \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_dit \
     --quantized \
@@ -196,7 +201,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_t5 \
     --quantized
@@ -209,7 +214,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_t5_umt5-xxl-enc-fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.bfloat16 \
     --model_type wan_t5 \
     --quantized
@@ -224,7 +229,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-int8 \
-    --linear_dtype torch.int8 \
+    --linear_type int8 \
     --non_linear_dtype torch.float16 \
     --model_type wan_clip \
     --quantized
@@ -237,7 +242,7 @@ python converter.py \
     --output /path/to/output \
     --output_ext .pth \
     --output_name models_clip_open-clip-xlm-roberta-large-vit-huge-14-fp8 \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --non_linear_dtype torch.float16 \
     --model_type wan_clip \
     --quantized
@@ -318,7 +323,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file
 ```
@@ -333,7 +338,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file \
     --comfyui_mode
 ```
@@ -349,7 +354,7 @@ python converter.py \
     --lora_path /path/to/lora.safetensors \
     --lora_strength 1.0 \
     --quantized \
-    --linear_dtype torch.float8_e4m3fn \
+    --linear_type fp8 \
     --single_file \
     --comfyui_mode \
     --full_quantized
...
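A note on the comfyui_mode flag threaded through the quantizers above: in the default path scales are kept per output row, while ComfyUI mode uses a single per-tensor scale. Below is a minimal fp8 sketch of that difference, assuming only a torch build with float8_e4m3fn support; it is an illustration, not the tool's own code.

```python
import torch


def fp8_quantize(w: torch.Tensor, comfyui_mode: bool = False):
    """Symmetric fp8 (e4m3) quantization with per-row or per-tensor scales."""
    finfo = torch.finfo(torch.float8_e4m3fn)
    if not comfyui_mode:
        # Default path: one scale per output row.
        max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
    else:
        # ComfyUI mode: a single scalar scale for the whole tensor.
        max_val = w.abs().max()
    scales = max_val / finfo.max
    w_q = torch.clamp(w / scales, finfo.min, finfo.max).to(torch.float8_e4m3fn)
    return w_q, scales


w = torch.randn(1024, 1024)
_, s_row = fp8_quantize(w, comfyui_mode=False)
_, s_tensor = fp8_quantize(w, comfyui_mode=True)
print("per-row scales:", tuple(s_row.shape))       # (1024, 1)
print("per-tensor scale:", tuple(s_tensor.shape))  # () -- a 0-dim scalar
```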