Commit 2ef8e74e authored by gushiqiao, committed by GitHub

Merge pull request #62 from ModelTC/dev_fixbugs

fix.
parents 5c241f86 9121bad1
@@ -2,12 +2,11 @@ import torch
 from abc import ABCMeta, abstractmethod
 from vllm import _custom_ops as ops
-# import sgl_kernel
+import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
 from lightx2v.utils.envs import *
 from loguru import logger
-from safetensors import safe_open

 try:
     import q8_kernels.functional as Q8F
@@ -21,9 +20,11 @@ except ImportError:
 class MMWeightTemplate(metaclass=ABCMeta):
-    def __init__(self, weight_name, bias_name):
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
         self.weight_name = weight_name
         self.bias_name = bias_name
+        self.lazy_load = lazy_load
+        self.lazy_load_file = lazy_load_file
         self.config = {}

     @abstractmethod
@@ -61,8 +62,8 @@ class MMWeightTemplate(metaclass=ABCMeta):
 @MM_WEIGHT_REGISTER("Default")
 class MMWeight(MMWeightTemplate):
-    def __init__(self, weight_name, bias_name):
-        super().__init__(weight_name, bias_name)
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

     def load(self, weight_dict):
         self.weight = weight_dict[self.weight_name].t()
@@ -90,8 +91,8 @@ class MMWeight(MMWeightTemplate):
 @MM_WEIGHT_REGISTER("Default-Force-FP32")
 class MMWeightForceFP32(MMWeight):
-    def __init__(self, weight_name, bias_name):
-        super().__init__(weight_name, bias_name)
+    def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)

     def load(self, weight_dict):
         super().load(weight_dict)
@@ -102,7 +103,7 @@ class MMWeightForceFP32(MMWeight):
 class MMWeightQuantTemplate(MMWeightTemplate):
     def __init__(self, weight_name, bias_name, lazy_load=False, lazy_load_file=None):
-        super().__init__(weight_name, bias_name)
+        super().__init__(weight_name, bias_name, lazy_load, lazy_load_file)
         self.weight_scale_name = self.weight_name.removesuffix(".weight") + ".weight_scale"
         self.load_func = None
         self.weight_need_transpose = True
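Note that the patch only threads the new `lazy_load` / `lazy_load_file` arguments through the constructors; the lazy-loading logic itself is outside the shown hunks. A minimal usage sketch, assuming `lazy_load_file` is an already-open safetensors handle created by the caller (consistent with the `safe_open` import being removed from this module) and with hypothetical tensor keys:

```python
from safetensors import safe_open

# Open the checkpoint once; tensors are then read on demand rather than
# materialized into a full state dict up front.
with safe_open("wan_fp8.safetensors", framework="pt", device="cpu") as f:
    mm = MMWeight(
        weight_name="blocks.0.attn.qkv.weight",  # hypothetical key
        bias_name="blocks.0.attn.qkv.bias",      # hypothetical key
        lazy_load=True,
        lazy_load_file=f,
    )
    # A lazy-aware load() could fetch only the tensors it needs:
    weight = mm.lazy_load_file.get_tensor(mm.weight_name).t()
```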
...
@@ -46,7 +46,7 @@ python converter.py \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name wan_fp8 \
-    --dtype torch.float8_e4m3_fn \
+    --dtype torch.float8_e4m3fn \
     --model_type wan_dit
 ```
@@ -70,7 +70,7 @@ python converter.py \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name hunyuan_fp8 \
-    --dtype torch.float8_e4m3_fn \
+    --dtype torch.float8_e4m3fn \
     --model_type hunyuan_dit
 ```
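Both documentation hunks correct the same typo: the PyTorch FP8 dtype is spelled `torch.float8_e4m3fn`, with no underscore before `fn`. `torch.float8_e4m3_fn` does not exist, so the documented command would fail as soon as the converter resolves the dtype string. A quick check in recent PyTorch versions:

```python
import torch

# Correct spelling: the attribute exists and works as a dtype.
x = torch.zeros(4, dtype=torch.float8_e4m3fn)
print(x.dtype)  # torch.float8_e4m3fn

# Old spelling from the docs: no such attribute.
print(hasattr(torch, "float8_e4m3_fn"))  # False
```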
...
@@ -46,7 +46,7 @@ python converter.py \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name wan_fp8 \
-    --dtype torch.float8_e4m3_fn \
+    --dtype torch.float8_e4m3fn \
     --model_type wan_dit
 ```
@@ -70,7 +70,7 @@ python converter.py \
     --output /Path/To/output \
     --output_ext .pth \
     --output_name hunyuan_fp8 \
-    --dtype torch.float8_e4m3_fn \
+    --dtype torch.float8_e4m3fn \
     --model_type hunyuan_dit
 ```
...