""" This file is part of ComfyUI. Copyright (C) 2024 Stability AI This program is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with this program. If not, see . """ import torch import logging import comfy.model_management from comfy.cli_args import args, PerformanceFeature import comfy.float import comfy.rmsnorm import contextlib try: from lmslim import quant_ops import lmslimquant from lmslim.layers.gemm.int8_utils import per_token_quant_int8 except Exception: print("INFO: Please install lmslim if you want to infergptq or awq or w8a8 model") def scaled_dot_product_attention(q, k, v, *args, **kwargs): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) try: if torch.cuda.is_available(): from torch.nn.attention import SDPBackend, sdpa_kernel import inspect if "set_priority" in inspect.signature(sdpa_kernel).parameters: SDPA_BACKEND_PRIORITY = [ SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION, SDPBackend.MATH, ] SDPA_BACKEND_PRIORITY.insert(0, SDPBackend.CUDNN_ATTENTION) def scaled_dot_product_attention(q, k, v, *args, **kwargs): with sdpa_kernel(SDPA_BACKEND_PRIORITY, set_priority=True): return torch.nn.functional.scaled_dot_product_attention(q, k, v, *args, **kwargs) else: logging.warning("Torch version too old to set sdpa backend priority.") except (ModuleNotFoundError, TypeError): logging.warning("Could not set sdpa backend priority.") cast_to = comfy.model_management.cast_to #TODO: remove once no more references def cast_to_input(weight, input, non_blocking=False, copy=True): return comfy.model_management.cast_to(weight, input.dtype, input.device, non_blocking=non_blocking, copy=copy) def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: if dtype is None: dtype = input.dtype if bias_dtype is None: bias_dtype = dtype if device is None: device = input.device offload_stream = comfy.model_management.get_offload_stream(device) if offload_stream is not None: wf_context = offload_stream else: wf_context = contextlib.nullcontext() bias = None non_blocking = comfy.model_management.device_supports_non_blocking(device) if s.bias is not None: has_function = len(s.bias_function) > 0 bias = comfy.model_management.cast_to(s.bias, bias_dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: for f in s.bias_function: bias = f(bias) has_function = len(s.weight_function) > 0 weight = comfy.model_management.cast_to(s.weight, dtype, device, non_blocking=non_blocking, copy=has_function, stream=offload_stream) if has_function: with wf_context: for f in s.weight_function: weight = f(weight) comfy.model_management.sync_stream(device, offload_stream) return weight, bias class CastWeightBiasOp: comfy_cast_weights = False weight_function = [] bias_function = [] class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def reset_parameters(self): return None def forward_comfy_cast_weights(self, input): weight, bias = cast_bias_weight(self, input) return 
class disable_weight_init:
    class Linear(torch.nn.Linear, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv1d(torch.nn.Conv1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv2d(torch.nn.Conv2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Conv3d(torch.nn.Conv3d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return self._conv_forward(input, weight, bias)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class GroupNorm(torch.nn.GroupNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.group_norm(input, self.num_groups, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class LayerNorm(torch.nn.LayerNorm, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
                bias = None
            return torch.nn.functional.layer_norm(input, self.normalized_shape, weight, bias, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class RMSNorm(comfy.rmsnorm.RMSNorm, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input):
            if self.weight is not None:
                weight, bias = cast_bias_weight(self, input)
            else:
                weight = None
            return comfy.rmsnorm.rms_norm(input, weight, self.eps)  # TODO: switch to commented out line when old torch is deprecated
            # return torch.nn.functional.rms_norm(input, self.normalized_shape, weight, self.eps)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)
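    # The transposed-conv variants below mirror torch.nn.ConvTransposeNd.forward():
    # they recompute output_padding from the requested output_size and then call
    # the functional op with the freshly cast weight and bias.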
    class ConvTranspose2d(torch.nn.ConvTranspose2d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 2
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose2d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class ConvTranspose1d(torch.nn.ConvTranspose1d, CastWeightBiasOp):
        def reset_parameters(self):
            return None

        def forward_comfy_cast_weights(self, input, output_size=None):
            num_spatial_dims = 1
            output_padding = self._output_padding(
                input, output_size, self.stride, self.padding, self.kernel_size,
                num_spatial_dims, self.dilation)

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.conv_transpose1d(
                input, weight, bias, self.stride, self.padding,
                output_padding, self.groups, self.dilation)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                return super().forward(*args, **kwargs)

    class Embedding(torch.nn.Embedding, CastWeightBiasOp):
        def reset_parameters(self):
            self.bias = None
            return None

        def forward_comfy_cast_weights(self, input, out_dtype=None):
            output_dtype = out_dtype
            if self.weight.dtype == torch.float16 or self.weight.dtype == torch.bfloat16:
                out_dtype = None
            weight, bias = cast_bias_weight(self, device=input.device, dtype=out_dtype)
            return torch.nn.functional.embedding(input, weight, self.padding_idx, self.max_norm, self.norm_type, self.scale_grad_by_freq, self.sparse).to(dtype=output_dtype)

        def forward(self, *args, **kwargs):
            if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
                return self.forward_comfy_cast_weights(*args, **kwargs)
            else:
                if "out_dtype" in kwargs:
                    kwargs.pop("out_dtype")
                return super().forward(*args, **kwargs)

    @classmethod
    def conv_nd(s, dims, *args, **kwargs):
        if dims == 2:
            return s.Conv2d(*args, **kwargs)
        elif dims == 3:
            return s.Conv3d(*args, **kwargs)
        else:
            raise ValueError(f"unsupported dimensions: {dims}")


class manual_cast(disable_weight_init):
    class Linear(disable_weight_init.Linear):
        comfy_cast_weights = True

    class Conv1d(disable_weight_init.Conv1d):
        comfy_cast_weights = True

    class Conv2d(disable_weight_init.Conv2d):
        comfy_cast_weights = True

    class Conv3d(disable_weight_init.Conv3d):
        comfy_cast_weights = True

    class GroupNorm(disable_weight_init.GroupNorm):
        comfy_cast_weights = True

    class LayerNorm(disable_weight_init.LayerNorm):
        comfy_cast_weights = True

    class ConvTranspose2d(disable_weight_init.ConvTranspose2d):
        comfy_cast_weights = True

    class ConvTranspose1d(disable_weight_init.ConvTranspose1d):
        comfy_cast_weights = True

    class RMSNorm(disable_weight_init.RMSNorm):
        comfy_cast_weights = True

    class Embedding(disable_weight_init.Embedding):
        comfy_cast_weights = True


from typing import Optional
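# W8A8 per-channel path (requires lmslim): weights are quantized symmetrically
# per output channel to int8 with scale = max|w_row| / 127 the first time the
# layer runs, activations are quantized per token at runtime, and the int8 GEMM
# plus dequantization is handled by quant_ops.hipblaslt_w8a8_gemm.
#
# Illustrative numbers only: a weight row whose largest magnitude is 0.5 gets
# scale 0.5 / 127 ≈ 0.00394, so an entry of 0.25 is stored as round(63.5) = 64
# and dequantizes back to 64 * 0.00394 ≈ 0.252.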
class manual_cast_int8_per_channel(manual_cast):
    class Linear(torch.nn.Module):
        def __init__(self, in_features, out_features, bias=True, dtype=None, device=None):
            super().__init__()
            self.in_features = in_features
            self.out_features = out_features
            self.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype, device=device), requires_grad=False)
            if bias:
                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=dtype, device=device))
            else:
                self.register_parameter("bias", None)
            self.weight_quant = None
            self.weight_scale = None

        def blaslt_scaled_mm(self, a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor,
                             out_dtype: torch.dtype, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
            m = a.shape[0]
            n = b.shape[0]
            k = a.shape[1]
            _, out = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a.to(torch.float32), scale_b.to(torch.float32), m, n, k, 'NT', out_dtype)
            if bias is not None:
                out += bias
            return out

        def weight_quant_int8(self, weight):
            # Symmetric per-output-channel (per-row) int8 quantization of the fp weight.
            org_w_shape = weight.shape
            w = weight.to(torch.bfloat16)
            max_val = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5)
            qmin, qmax = -128, 127
            scales = (max_val / qmax).float()
            w_q = torch.clamp(torch.round(w / scales), qmin, qmax).to(torch.int8)
            assert torch.isnan(scales).sum() == 0
            assert torch.isnan(w_q).sum() == 0
            scales = scales.view(org_w_shape[0], -1)
            w_q = w_q.reshape(org_w_shape)
            return w_q, scales

        def forward(self, input):
            # return self.forward_calibration(input)
            dim = input.dim()
            if dim > 2:
                input = input.squeeze(0)
            if self.weight_quant is None:
                # Quantize the weight lazily on the first forward pass.
                self.weight_quant, self.weight_scale = self.weight_quant_int8(self.weight)
                if self.bias is not None:
                    self.bias = torch.nn.Parameter(self.bias.to(input.dtype))
            input_quant, input_scale = per_token_quant_int8(input)
            output_tensor = self.blaslt_scaled_mm(input_quant, self.weight_quant, input_scale, self.weight_scale, input.dtype, self.bias)
            if dim > 2:
                output_tensor = output_tensor.unsqueeze(0)
            return output_tensor
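# Unlike manual_cast_int8_per_channel above, manual_cast_int8 expects the
# checkpoint to already provide int8 weights and a per-output-channel
# weight_scale, so both are registered as buffers and only the activations are
# quantized (per token) at runtime.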
class manual_cast_int8(manual_cast):
    class Linear(torch.nn.Module, CastWeightBiasOp):
        __constants__ = ['in_features', 'out_features']
        in_features: int
        out_features: int
        weight: torch.Tensor

        def __init__(self, in_features: int, out_features: int, bias: bool = True, device=None, dtype=None) -> None:
            factory_kwargs = {'device': device, 'dtype': dtype}
            super().__init__()
            logging.info("Using int8 quantized Linear")
            self.in_features = in_features
            self.out_features = out_features
            # self.weight = Parameter(torch.empty((out_features, in_features), dtype=torch.int8, device=device))
            # self.weight_scale = Parameter(torch.empty((out_features, 1), **factory_kwargs))
            self.register_buffer("weight", torch.empty((out_features, in_features), dtype=torch.int8, device=device))
            self.register_buffer("weight_scale", torch.empty((out_features, 1), dtype=torch.float16, device=device))
            if bias:
                self.bias = torch.nn.Parameter(torch.empty(out_features, dtype=torch.float16, device=device))
            else:
                self.register_parameter('bias', None)
            self.reset_parameters()

        def reset_parameters(self) -> None:
            return None

        def verify_quant_gemm(self, input_q, weight_q, input_scale, weight_scale, out_dtype: torch.dtype, bias):
            # Reference implementation used to cross-check the hipBLASLt kernel:
            # 1. int8 GEMM on CPU, accumulated via int32
            y_q = (input_q.cpu().int() @ (weight_q.cpu().int().t()))
            # 2. dequantize with the per-token and per-channel scales
            y_deq = y_q * ((input_scale * weight_scale.t()).cpu())
            # 3. cast to the requested dtype and move back to the GPU
            return y_deq.to(out_dtype).cuda()

        def blaslt_scaled_mm(self, a: torch.Tensor, b: torch.Tensor, scale_a: torch.Tensor, scale_b: torch.Tensor, out_dtype: torch.dtype, bias) -> torch.Tensor:
            # b = b.t()
            m = a.shape[0]
            n = b.shape[0]
            k = a.shape[1]
            stat, output = quant_ops.hipblaslt_w8a8_gemm(a, b, scale_a, scale_b, m, n, k, 'NT', out_dtype)
            # output = matmul_int8(a, scale_a, b, scale_b, out_dtype, config=None)
            # status, output = torch.ops.lmslim.lightop_channel_int8_mm(a, b, scale_a, scale_b, out_dtype, bias)
            if bias is not None:
                output += bias
            return output

        def quantize_symmetric_per_row_int8(self, x: torch.Tensor):
            """
            Symmetric per-row (dim=1) INT8 quantization of the input x.

            Args:
                x: tensor of shape [B, N], dtype in {float32, float16, bfloat16}
            Returns:
                x_q: quantized int8 tensor, shape [B, N]
                scales: scale per row, shape [B, 1], float32
            """
            assert x.ndim == 2, f"Expected 2D input, got {x.shape}"
            assert x.dtype in [torch.float32, torch.float16, torch.bfloat16]

            # Step 1: max absolute value of each row -> shape [B, 1]
            max_abs = x.abs().amax(dim=1, keepdim=True)

            # Step 2: scale = max_abs / 127; avoid division by zero by using
            # scale = 1 for rows that are entirely zero.
            scales = torch.where(
                max_abs == 0,
                torch.tensor(1.0, dtype=x.dtype, device=x.device),
                max_abs / 127.0
            )  # shape [B, 1], dtype = x.dtype

            # Step 3: quantize: x_q = round(x / scales). Do the intermediate math
            # in float32 to avoid bfloat16 precision issues.
            x_f32 = x.to(torch.float32)
            scales_f32 = scales.to(torch.float32)
            x_q_f32 = torch.round(x_f32 / scales_f32)

            # Step 4: clamp to [-127, 127] and cast to int8.
            x_q = torch.clamp(x_q_f32, -127, 127).to(torch.int8)

            return x_q, scales_f32

        def forward(self, input_tensor: torch.Tensor):
            dim = input_tensor.dim()
            if dim > 2:
                input_tensor = input_tensor.squeeze(0)
            dtype = input_tensor.dtype
            input_tensor_quant, input_tensor_scale = per_token_quant_int8(input_tensor)
            # input_tensor_quant, input_tensor_scale = self.quantize_symmetric_per_row_int8(input_tensor)
            output_tensor = self.blaslt_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
            # output_sf = self.verify_quant_gemm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale.to(torch.float32), dtype, self.bias)
            if dim > 2:
                output_tensor = output_tensor.unsqueeze(0)
            return output_tensor

        def extra_repr(self) -> str:
            return f'in_features={self.in_features}, out_features={self.out_features}, bias={self.bias is not None}'
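# fp8 fast path: when the layer's weight is stored as float8_e4m3fn, fp8_linear()
# runs the matmul with torch._scaled_mm using per-tensor scales. If no input
# scale is stored, the input is clamped to the e4m3 representable range (±448)
# and cast to fp8 with an implicit scale of 1.0; otherwise it is divided by
# scale_input first. Callers fall back to a regular cast + linear on failure.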
def fp8_linear(self, input):
    dtype = self.weight.dtype
    if dtype not in [torch.float8_e4m3fn]:
        return None

    tensor_2d = False
    if len(input.shape) == 2:
        tensor_2d = True
        input = input.unsqueeze(1)

    input_shape = input.shape
    input_dtype = input.dtype
    if len(input.shape) == 3:
        w, bias = cast_bias_weight(self, input, dtype=dtype, bias_dtype=input_dtype)
        w = w.t()

        scale_weight = self.scale_weight
        scale_input = self.scale_input
        if scale_weight is None:
            scale_weight = torch.ones((), device=input.device, dtype=torch.float32)
        else:
            scale_weight = scale_weight.to(input.device)

        if scale_input is None:
            scale_input = torch.ones((), device=input.device, dtype=torch.float32)
            input = torch.clamp(input, min=-448, max=448, out=input)
            input = input.reshape(-1, input_shape[2]).to(dtype).contiguous()
        else:
            scale_input = scale_input.to(input.device)
            input = (input * (1.0 / scale_input).to(input_dtype)).reshape(-1, input_shape[2]).to(dtype).contiguous()

        if bias is not None:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, bias=bias, scale_a=scale_input, scale_b=scale_weight)
        else:
            o = torch._scaled_mm(input, w, out_dtype=input_dtype, scale_a=scale_input, scale_b=scale_weight)

        if isinstance(o, tuple):
            o = o[0]

        if tensor_2d:
            return o.reshape(input_shape[0], -1)

        return o.reshape((-1, input_shape[1], self.weight.shape[0]))

    return None


class fp8_ops(manual_cast):
    class Linear(manual_cast.Linear):
        def reset_parameters(self):
            self.scale_weight = None
            self.scale_input = None
            return None

        def forward_comfy_cast_weights(self, input):
            try:
                out = fp8_linear(self, input)
                if out is not None:
                    return out
            except Exception as e:
                logging.info("Exception during fp8 op: {}".format(e))

            weight, bias = cast_bias_weight(self, input)
            return torch.nn.functional.linear(input, weight, bias)


def scaled_fp8_ops(fp8_matrix_mult=False, scale_input=False, override_dtype=None):
    logging.info("Using scaled fp8: fp8 matrix mult: {}, scale input: {}".format(fp8_matrix_mult, scale_input))

    class scaled_fp8_op(manual_cast):
        class Linear(manual_cast.Linear):
            def __init__(self, *args, **kwargs):
                if override_dtype is not None:
                    kwargs['dtype'] = override_dtype
                super().__init__(*args, **kwargs)

            def reset_parameters(self):
                if not hasattr(self, 'scale_weight'):
                    self.scale_weight = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)

                if not scale_input:
                    self.scale_input = None

                if not hasattr(self, 'scale_input'):
                    self.scale_input = torch.nn.parameter.Parameter(data=torch.ones((), device=self.weight.device, dtype=torch.float32), requires_grad=False)
                return None

            def forward_comfy_cast_weights(self, input):
                if fp8_matrix_mult:
                    out = fp8_linear(self, input)
                    if out is not None:
                        return out

                weight, bias = cast_bias_weight(self, input)

                if weight.numel() < input.numel():  # TODO: optimize
                    return torch.nn.functional.linear(input, weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype), bias)
                else:
                    return torch.nn.functional.linear(input * self.scale_weight.to(device=weight.device, dtype=weight.dtype), weight, bias)

            def convert_weight(self, weight, inplace=False, **kwargs):
                if inplace:
                    weight *= self.scale_weight.to(device=weight.device, dtype=weight.dtype)
                    return weight
                else:
                    return weight * self.scale_weight.to(device=weight.device, dtype=weight.dtype)

            def set_weight(self, weight, inplace_update=False, seed=None, **kwargs):
                weight = comfy.float.stochastic_rounding(weight / self.scale_weight.to(device=weight.device, dtype=weight.dtype), self.weight.dtype, seed=seed)
                if inplace_update:
                    self.weight.data.copy_(weight)
                else:
                    self.weight = torch.nn.Parameter(weight, requires_grad=False)

    return scaled_fp8_op


CUBLAS_IS_AVAILABLE = False
try:
    from cublas_ops import CublasLinear
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    pass

if CUBLAS_IS_AVAILABLE:
    class cublas_ops(disable_weight_init):
        class Linear(CublasLinear, disable_weight_init.Linear):
            def reset_parameters(self):
                return None

            def forward_comfy_cast_weights(self, input):
                return super().forward(input)

            def forward(self, *args, **kwargs):
                return super().forward(*args, **kwargs)
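# pick_operations() chooses the op namespace for a model, in priority order:
# lmslim int8 (W8A8) -> scaled fp8 -> plain fp8 fast path -> cublas half-precision
# linear -> disable_weight_init when the weights already match the compute dtype
# -> manual_cast otherwise.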
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, scaled_fp8=None, int8_optimizations=None):
    if int8_optimizations:
        return manual_cast_int8_per_channel

    fp8_compute = comfy.model_management.supports_fp8_compute(load_device)
    if scaled_fp8 is not None:
        return scaled_fp8_ops(fp8_matrix_mult=fp8_compute and fp8_optimizations, scale_input=fp8_optimizations, override_dtype=scaled_fp8)

    if (
        fp8_compute and
        (fp8_optimizations or PerformanceFeature.Fp8MatrixMultiplication in args.fast) and
        not disable_fast_fp8
    ):
        return fp8_ops

    if (
        PerformanceFeature.CublasOps in args.fast and
        CUBLAS_IS_AVAILABLE and
        weight_dtype == torch.float16 and
        (compute_dtype == torch.float16 or compute_dtype is None)
    ):
        logging.info("Using cublas ops")
        return cublas_ops

    if compute_dtype is None or weight_dtype == compute_dtype:
        return disable_weight_init

    return manual_cast
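# Usage sketch (illustrative only; variable names are placeholders): model code
# takes the class returned by pick_operations() as its `operations` namespace and
# instantiates ops.Linear / ops.Conv2d / ... instead of the torch classes, e.g.
#
#   ops = pick_operations(torch.float16, torch.float16, load_device=load_device)
#   proj = ops.Linear(4096, 4096, dtype=torch.float16, device=offload_device)
#
# so weight casting and the quantized / fp8 fast paths are applied transparently.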