import torch

from .mm_weight import MMWeight
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import FloatQuantizer


@MM_WEIGHT_REGISTER('Calib')
class MMWeightCalib(MMWeight):
    """Calibration matmul weight: quantizes a full-precision weight at load
    time and keeps the quantized payload, per-channel scales, and optional
    zero points."""

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        # Calibration only makes sense for a quantized mm_type.
        assert self.config and self.config.get('mm_type', 'Default') != 'Default'
        self.weight = weight_dict[self.weight_name]
        self.get_quantizer()
        shape_and_dtype = self.get_quant_shape_and_dtype(self.weight.shape)
        # Quantize once at load time; keep the quantized tensor plus the
        # dequantization metadata (scales, zero points).
        self.realq_weight, self.scales, self.zeros = self.w_quantizer.real_quant_tensor(self.weight)
        self.realq_weight = self.realq_weight.view(shape_and_dtype['tensor'][0]).contiguous().to(shape_and_dtype['tensor'][1])
        self.scales = self.scales.view(shape_and_dtype['scales'][0]).contiguous().to(shape_and_dtype['scales'][1])
        # Symmetric schemes return no zero points; also guard against a
        # mm_type whose layout declares none, so the view below cannot
        # dereference a None entry.
        if self.zeros is not None and shape_and_dtype['zeros'] is not None:
            self.zeros = self.zeros.view(shape_and_dtype['zeros'][0]).contiguous().to(shape_and_dtype['zeros'][1])

    def apply(self, input_tensor):
        # Delegate the actual matmul to the base class.
        return super().apply(input_tensor)

    def get_quantizer(self):
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            # fp8 (e4m3) weights and activations, symmetric, per-channel;
            # activations are quantized dynamically at runtime.
            self.w_setting = {'bit': 'e4m3', 'symmetric': True, 'granularity': 'channel'}
            self.a_setting = {'bit': 'e4m3', 'symmetric': True, 'granularity': 'channel'}
            self.w_quantizer = FloatQuantizer(**self.w_setting)
            self.a_quantizer = FloatQuantizer(**self.a_setting)
            self.act_dynamic_quant = True
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')

    def get_quant_shape_and_dtype(self, shape):
        if self.config['mm_type'] == 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm':
            return {
                # The storage dtype must match the e4m3 quantizer configured
                # above; torch.float8_e5m2 here would reinterpret the bits
                # under the wrong format.
                'tensor': (shape, torch.float8_e4m3fn),
                'scales': ((shape[0], 1), torch.float32),
                'zeros': None,
            }
        else:
            raise NotImplementedError(f'Unsupported mm_type: {self.config["mm_type"]}')
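

if __name__ == '__main__':
    # Minimal smoke-test sketch, not part of the module's public wiring.
    # Assumptions: `config` can be assigned directly on the instance (the
    # real lightx2v setup may pass it through a dedicated hook), and
    # FloatQuantizer.real_quant_tensor returns (quantized, scales, zeros).
    # The weight name and shape below are made up for illustration.
    calib = MMWeightCalib('proj.weight', 'proj.bias')
    calib.config = {'mm_type': 'W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm'}
    calib.load({'proj.weight': torch.randn(128, 256)})
    # Expect: torch.float8_e4m3fn payload and one fp32 scale per out-channel.
    print(calib.realq_weight.dtype, calib.scales.shape)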