import itertools

import torch
from pytorch_quantization import calib
from pytorch_quantization import nn as quant_nn
from pytorch_quantization import quant_modules
from tqdm import tqdm


def _tensor_quantizers(model):
    """Yield every ``quant_nn.TensorQuantizer`` submodule of *model*."""
    for module in model.modules():
        if isinstance(module, quant_nn.TensorQuantizer):
            yield module


def _set_calibration_mode(model, calibrating):
    """Switch every TensorQuantizer in *model* into or out of calibration mode.

    With ``calibrating=True``: quantizers that own a calibrator have
    quantization disabled and calibration enabled; quantizers without a
    calibrator are disabled entirely. With ``calibrating=False`` the reverse
    calls are made (enable quantization / disable calibration, or enable the
    quantizer).
    """
    for module in _tensor_quantizers(model):
        if module._calibrator is not None:
            if calibrating:
                module.disable_quant()
                module.enable_calib()
            else:
                module.enable_quant()
                module.disable_calib()
        elif calibrating:
            module.disable()
        else:
            module.enable()


def collect_stats(model, data_loader, num_batches, device):
    """Run up to *num_batches* batches through *model* to collect calibration statistics.

    Quantizers are put into calibration mode for the duration of the passes
    and are always switched back afterwards, even if a forward pass raises.

    Args:
        model: Module containing ``quant_nn.TensorQuantizer`` submodules.
        data_loader: Iterable yielding ``(inputs, target)`` pairs; only the
            inputs are used.
        num_batches: Maximum number of batches to process (values <= 0
            process none).
        device: Device the inputs are moved to before each forward pass.
    """
    _set_calibration_mode(model, calibrating=True)
    try:
        # Gradients are not needed to gather statistics; avoid building
        # autograd graphs for every calibration batch.
        with torch.no_grad():
            # islice processes exactly num_batches batches; the previous
            # "break after forward" check ran one extra batch.
            batches = itertools.islice(data_loader, max(num_batches, 0))
            for image, _ in tqdm(batches, total=num_batches):
                model(image.to(device))
    finally:
        # Restore normal quantization even when calibration fails midway.
        _set_calibration_mode(model, calibrating=False)


def compute_amax(model, device, **kwargs):
    """Load calibrated amax values into every calibrated TensorQuantizer of *model*.

    Quantizers using ``calib.MaxCalibrator`` are loaded with no arguments;
    all other calibrators receive ``**kwargs`` via ``load_calib_amax``.
    The model is moved to *device* afterwards.

    Args:
        model: Module containing ``quant_nn.TensorQuantizer`` submodules that
            have already collected statistics (see ``collect_stats``).
        device: Device the model is moved to once amax values are loaded.
        **kwargs: Forwarded to ``load_calib_amax`` for non-max calibrators
            (presumably options such as the histogram method / percentile --
            confirm against the pytorch-quantization docs).
    """
    for module in _tensor_quantizers(model):
        calibrator = module._calibrator
        if calibrator is None:
            continue
        if isinstance(calibrator, calib.MaxCalibrator):
            module.load_calib_amax()
        else:
            module.load_calib_amax(**kwargs)
    # NOTE(review): presumably needed because freshly loaded amax tensors may
    # not be on the target device yet -- confirm.
    model.to(device)