from typing import TYPE_CHECKING, Any, Dict, List, Union

from ...utils import (
    get_module_from_name,
    is_accelerate_available,
    is_nvidia_modelopt_available,
    is_torch_available,
    logging,
)
from ..base import DiffusersQuantizer


if TYPE_CHECKING:
    from ...models.modeling_utils import ModelMixin


if is_torch_available():
    import torch
    import torch.nn as nn

if is_accelerate_available():
    from accelerate.utils import set_module_tensor_to_device


logger = logging.get_logger(__name__)


class NVIDIAModelOptQuantizer(DiffusersQuantizer):
    r"""
    Diffusers Quantizer for TensorRT Model Optimizer (ModelOpt).
    """

    use_keep_in_fp32_modules = True
    requires_calibration = False
    required_packages = ["nvidia_modelopt"]

    def __init__(self, quantization_config, **kwargs):
        super().__init__(quantization_config, **kwargs)

    def validate_environment(self, *args, **kwargs):
        if not is_nvidia_modelopt_available():
            raise ImportError(
                "Loading an nvidia-modelopt quantized model requires the nvidia-modelopt library (`pip install nvidia-modelopt`)."
            )

        self.offload = False

        device_map = kwargs.get("device_map", None)
        if isinstance(device_map, dict):
            if "cpu" in device_map.values() or "disk" in device_map.values():
                if self.pre_quantized:
                    raise ValueError(
                        "You are attempting to perform cpu/disk offload with a pre-quantized modelopt model. "
                        "This is not supported yet. Please remove the CPU or disk device from the `device_map` argument."
                    )
                else:
                    self.offload = True

    def check_if_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        state_dict: Dict[str, Any],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. The import is kept local to prevent circular imports.
        from modelopt.torch.quantization.utils import is_quantized

        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            return True
        elif is_quantized(module) and "weight" in tensor_name:
            return True
        return False

    def create_quantized_param(
        self,
        model: "ModelMixin",
        param_value: "torch.Tensor",
        param_name: str,
        target_device: "torch.device",
        *args,
        **kwargs,
    ):
        """
        Create the quantized parameter by calling .calibrate() after setting it to the module.
        """
        # ModelOpt imports diffusers internally. The import is kept local to prevent circular imports.
        import modelopt.torch.quantization as mtq

        dtype = kwargs.get("dtype", torch.float32)
        module, tensor_name = get_module_from_name(model, param_name)
        if self.pre_quantized:
            module._parameters[tensor_name] = torch.nn.Parameter(param_value.to(device=target_device))
        else:
            set_module_tensor_to_device(model, param_name, target_device, param_value, dtype)
            mtq.calibrate(
                module, self.quantization_config.modelopt_config["algorithm"], self.quantization_config.forward_loop
            )
            mtq.compress(module)
            module.weight.requires_grad = False

    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
        # Leave ~10% headroom per device for buffers created during quantization.
        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
        return max_memory

    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
        if self.quantization_config.quant_type == "FP8":
            target_dtype = torch.float8_e4m3fn
        return target_dtype

    def update_torch_dtype(self, torch_dtype: "torch.dtype" = None) -> "torch.dtype":
        if torch_dtype is None:
            logger.info("You did not specify `torch_dtype` in `from_pretrained`. Setting it to `torch.float32`.")
            torch_dtype = torch.float32
        return torch_dtype

    def get_conv_param_names(self, model: "ModelMixin") -> List[str]:
        """
        Get parameter names for all convolutional layers in a HuggingFace ModelMixin. Includes Conv1d/2d/3d and
        ConvTranspose1d/2d/3d.
        """
        conv_types = (
            nn.Conv1d,
            nn.Conv2d,
            nn.Conv3d,
            nn.ConvTranspose1d,
            nn.ConvTranspose2d,
            nn.ConvTranspose3d,
        )
        conv_param_names = []
        for name, module in model.named_modules():
            if isinstance(module, conv_types):
                for param_name, _ in module.named_parameters(recurse=False):
                    conv_param_names.append(f"{name}.{param_name}")
        return conv_param_names

    def _process_model_before_weight_loading(
        self,
        model: "ModelMixin",
        device_map,
        keep_in_fp32_modules: List[str] = [],
        **kwargs,
    ):
        # ModelOpt imports diffusers internally. The import is kept local to prevent circular imports.
        import modelopt.torch.opt as mto

        if self.pre_quantized:
            return

        modules_to_not_convert = self.quantization_config.modules_to_not_convert

        if modules_to_not_convert is None:
            modules_to_not_convert = []
        if isinstance(modules_to_not_convert, str):
            modules_to_not_convert = [modules_to_not_convert]
        modules_to_not_convert.extend(keep_in_fp32_modules)

        if self.quantization_config.disable_conv_quantization:
            modules_to_not_convert.extend(self.get_conv_param_names(model))

        # Disable quantization for any module whose name matches an excluded pattern.
        for module in modules_to_not_convert:
            self.quantization_config.modelopt_config["quant_cfg"]["*" + module + "*"] = {"enable": False}

        self.quantization_config.modules_to_not_convert = modules_to_not_convert
        mto.apply_mode(model, mode=[("quantize", self.quantization_config.modelopt_config)])
        model.config.quantization_config = self.quantization_config

    def _process_model_after_weight_loading(self, model, **kwargs):
        # ModelOpt imports diffusers internally. The import is kept local to prevent circular imports.
        from modelopt.torch.opt import ModeloptStateManager

        if self.pre_quantized:
            return model

        # Keep the ModelOpt state only on the top-level model; drop duplicated state from submodules.
        for _, m in model.named_modules():
            if hasattr(m, ModeloptStateManager._state_key) and m is not model:
                ModeloptStateManager.remove_state(m)
        return model

    @property
    def is_trainable(self):
        return True

    @property
    def is_serializable(self):
        self.quantization_config.check_model_patching(operation="saving")
        return True