"docs/vscode:/vscode.git/clone" did not exist on "d7b20dd65d310f68d79d585fae4b268283d1f93a"
Commit 9121bad1 authored by gushiqiao

fix.

parent 5c241f86
@@ -2,12 +2,11 @@ import torch
 from abc import ABCMeta, abstractmethod
 from vllm import _custom_ops as ops
-# import sgl_kernel
+import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
 from lightx2v.utils.envs import *
 from loguru import logger
 from safetensors import safe_open
 try:
     import q8_kernels.functional as Q8F
@@ -21,9 +20,11 @@ except ImportError:
 class MMWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, bias_name):
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
         self.weight_name = weight_name
         self.bias_name = bias_name
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
         self.config = {}

     @abstractmethod
@@ -61,8 +62,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name):
-        super().__init__(weight_name, bias_name)
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

     def load(self, weight_dict):
         self.weight = weight_dict[self.weight_name].t()
@@ -90,8 +91,8 @@ class MMWeight(MMWeightTemplate):
 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
-    def __init__(self, weight_name, bias_name):
-        super().__init__(weight_name, bias_name)
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

     def load(self, weight_dict):
         super().load(weight_dict)
@@ -102,7 +103,7 @@ class MMWeightForceFP32(MMWeight):
 class MMWeightQuantTemplate(MMWeightTemplate):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name)
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
         self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
         self.load_func = None
         self.weight_need_transpose = True
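
For orientation, here is a minimal sketch of the two construction paths this change enables. Everything beyond the signature is an assumption for illustration: the weight names, the checkpoint filename, and especially how `lazy_load_file` is consumed, since this hunk only threads the parameters through and stores them as attributes.

```python
import torch
from safetensors import safe_open
# from <module patched above> import MMWeight  (module path is elided in this diff)

# Eager path: weights come from an in-memory state dict, as in MMWeight.load().
# The "blocks.0.attn.qkv.*" names are hypothetical placeholders.
weight_dict = {
    "blocks.0.attn.qkv.weight": torch.randn(128, 128),
    "blocks.0.attn.qkv.bias": torch.randn(128),
}
mm = MMWeight("blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias")
mm.load(weight_dict)

# Lazy path: hand the constructor a safetensors handle so tensors can be read
# on demand instead of materializing the full state dict up front.
handle = safe_open("model.safetensors", framework="pt", device="cpu")
mm_lazy = MMWeight(
    "blocks.0.attn.qkv.weight",
    "blocks.0.attn.qkv.bias",
    lazy_load=True,
    lazy_load_file=handle,  # hypothetical: the commit stores this attribute,
)                           # but its consumption is outside this hunk
```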
@@ -46,7 +46,7 @@ python converter.py \
 --output /Path/To/output \
 --output_ext .pth \
 --output_name wan_fp8 \
---dtype torch.float8_e4m3_fn \
+--dtype torch.float8_e4m3fn \
 --model_type wan_dit
 ```
@@ -70,7 +70,7 @@ python converter.py \
 --output /Path/To/output \
 --output_ext .pth \
 --output_name hunyuan_fp8 \
---dtype torch.float8_e4m3_fn \
+--dtype torch.float8_e4m3fn \
 --model_type hunyuan_dit
 ```
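
This rename matters because PyTorch spells its FP8 dtype `float8_e4m3fn`, with no underscore before `fn`; `torch.float8_e4m3_fn` does not exist, so resolving it raises `AttributeError`. A quick sanity check, assuming PyTorch 2.1+ where the FP8 dtypes are available (the `resolve_dtype` helper below is an illustrative guess at what a converter might do, not code from converter.py):

```python
import torch

# The real attribute has no underscore before "fn".
assert hasattr(torch, "float8_e4m3fn")
assert not hasattr(torch, "float8_e4m3_fn")

# Illustrative guess at how a converter might resolve the --dtype flag:
def resolve_dtype(flag: str) -> torch.dtype:
    name = flag.removeprefix("torch.")
    return getattr(torch, name)  # AttributeError for the misspelled name

print(resolve_dtype("torch.float8_e4m3fn"))  # torch.float8_e4m3fn
```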
@@ -46,7 +46,7 @@ python converter.py \
 --output /Path/To/output \
 --output_ext .pth \
 --output_name wan_fp8 \
---dtype torch.float8_e4m3_fn \
+--dtype torch.float8_e4m3fn \
 --model_type wan_dit
 ```
@@ -70,7 +70,7 @@ python converter.py \
 --output /Path/To/output \
 --output_ext .pth \
 --output_name hunyuan_fp8 \
---dtype torch.float8_e4m3_fn \
+--dtype torch.float8_e4m3fn \
 --model_type hunyuan_dit
 ```