import torch
from abc import ABCMeta, abstractmethod
from vllm import _custom_ops as ops
from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer


class MMWeightTemplate(metaclass=ABCMeta):
    def __init__(self, weight_name, bias_name):
        self.weight_name = weight_name
        self.bias_name = bias_name
        self.config = {}

    @abstractmethod
    def load(self, weight_dict):
        pass

    @abstractmethod
    def apply(self, input_tensor):
        pass

    def set_config(self, config=None):
        if config is not None:
            self.config = config


@MM_WEIGHT_REGISTER('Default')
class MMWeight(MMWeightTemplate):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        # Weight is stored as (out_features, in_features); transpose once so apply() can call mm/addmm directly.
        self.weight = weight_dict[self.weight_name].t().cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        if self.bias is None:
            return torch.mm(input_tensor, self.weight, out=output_tensor)
        return torch.addmm(self.bias, input_tensor, self.weight, out=output_tensor)

    def to_cpu(self):
        self.weight = self.weight.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


@MM_WEIGHT_REGISTER('Default-Force-FP32')
class MMWeightForceFP32(MMWeight):
    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        super().load(weight_dict)
        self.weight = self.weight.to(torch.float32)
        if self.bias is not None:
            self.bias = self.bias.to(torch.float32)


@MM_WEIGHT_REGISTER('W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm')
class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
    '''
    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: fp8 per-channel sym
        Act: fp8 per-channel dynamic sym
        Kernel: vllm
    '''

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get('weight_auto_quant', True):
            # High-precision checkpoint: quantize the weight to fp8 e4m3, per-channel symmetric, at load time.
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = FloatQuantizer('e4m3', True, 'channel')
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            # Pre-quantized checkpoint: load the fp8 weight and its per-channel scale as-is.
            # removesuffix (not rstrip) so only the literal '.weight' suffix is dropped from the key.
            self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix('.weight') + '.weight_scale'].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        # Dynamic per-token fp8 quantization of the activation, then fp8 x fp8 GEMM via the vllm cutlass kernel.
        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        self.weight_scale = self.weight_scale.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        self.weight_scale = self.weight_scale.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


@MM_WEIGHT_REGISTER('W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm')
class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
    '''
    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm

    Quant MM:
        Weight: int8 per-channel sym
        Act: int8 per-channel dynamic sym
        Kernel: vllm
    '''

    def __init__(self, weight_name, bias_name):
        super().__init__(weight_name, bias_name)

    def load(self, weight_dict):
        if self.config.get('weight_auto_quant', True):
            # High-precision checkpoint: quantize the weight to int8, per-channel symmetric, at load time.
            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
            w_quantizer = IntegerQuantizer(8, True, 'channel')
            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
            self.weight = self.weight.to(torch.int8).t().cuda()
            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
        else:
            # Pre-quantized checkpoint: load the int8 weight and its per-channel scale as-is.
            self.weight = weight_dict[self.weight_name].t().cuda()
            self.weight_scale = weight_dict[self.weight_name.removesuffix('.weight') + '.weight_scale'].cuda()
        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None

    def apply(self, input_tensor):
        shape = (input_tensor.shape[0], self.weight.shape[1])
        dtype = input_tensor.dtype
        device = input_tensor.device
        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
        # Dynamic per-token int8 quantization of the activation, then int8 x int8 GEMM via the vllm cutlass kernel.
        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
        return output_tensor

    def to_cpu(self):
        self.weight = self.weight.cpu()
        self.weight_scale = self.weight_scale.cpu()
        if self.bias is not None:
            self.bias = self.bias.cpu()

    def to_cuda(self):
        self.weight = self.weight.cuda()
        self.weight_scale = self.weight_scale.cuda()
        if self.bias is not None:
            self.bias = self.bias.cuda()


if __name__ == '__main__':
    # Case 1: pre-quantized fp8 checkpoint (weight_auto_quant=False); the scale is read from 'xx.weight_scale'.
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096).to(torch.float8_e4m3fn),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
        'xx.weight_scale': torch.randn(8192, 1).to(torch.float32),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': False})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    # Case 2: high-precision checkpoint quantized to fp8 on load (weight_auto_quant=True).
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)

    # Case 3: high-precision checkpoint quantized to int8 on load (weight_auto_quant=True).
    weight_dict = {
        'xx.weight': torch.randn(8192, 4096),
        'xx.bias': torch.randn(8192).to(torch.bfloat16),
    }
    mm_weight = MM_WEIGHT_REGISTER['W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm']('xx.weight', 'xx.bias')
    mm_weight.set_config({'weight_auto_quant': True})
    mm_weight.load(weight_dict)
    input_tensor = torch.randn(1024, 4096).to(torch.bfloat16).cuda()
    output_tensor = mm_weight.apply(input_tensor)
    print(output_tensor.shape)