Commit 9b2946b6 authored by Casper Hansen

Add deprecation warning

parent fe314160
@@ -2,12 +2,13 @@ import os
 import gc
 import json
 import torch
+import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM
+from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
 
@@ -254,7 +255,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False):
+                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
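For context, a minimal sketch of calling the updated loader with the new argument. The wrapper class and import path follow AutoAWQ's public API (the wrapper supplies model_type itself); the checkpoint path and weight filename below are hypothetical placeholders:

from awq import AutoAWQForCausalLM

# Hypothetical paths. version='GEMV' selects the new kernel format;
# the default version='GEMM' keeps existing checkpoints loading as before.
model = AutoAWQForCausalLM.from_quantized(
    'my-model-awq',            # hypothetical local dir or Hub repo
    'awq_model_w4_g128.pt',    # hypothetical quantized weight file
    fuse_layers=False,
    version='GEMV',
)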
@@ -276,7 +277,7 @@ class BaseAWQForCausalLM(nn.Module):
                 quant_config = json.loads(file.read())
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
 
         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
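Note that the fallback above silently assumes GEMM, so a checkpoint quantized as GEMV should ship an explicit config. A sketch of writing one; the quant_config.json filename is an assumption based on the loading branch above:

import json

# Record the weight format so the loader does not fall back to the
# GEMM default; the filename is assumed to match what from_quantized reads.
quant_config = {'zero_point': True, 'q_group_size': 128, 'w_bit': 4, 'version': 'GEMV'}
with open('quant_config.json', 'w') as f:
    json.dump(quant_config, f)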
@@ -294,7 +295,7 @@ class BaseAWQForCausalLM(nn.Module):
 
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config)
+            self._load_quantized_modules(self, model, quant_config, version)
 
         model.tie_weights()
@@ -334,9 +335,14 @@ class BaseAWQForCausalLM(nn.Module):
 
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
-    def _load_quantized_modules(self, model, quant_config):
+    def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
+
+        if version == 'GEMM':
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')
 
         # Get blocks of model
         layers = self.get_model_layers(model)
@@ -352,8 +358,17 @@ class BaseAWQForCausalLM(nn.Module):
 
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
-                q_linear = WQLinear_GEMM.from_linear(
-                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
+                if version == 'GEMM':
+                    q_linear_module = WQLinear_GEMM
+                elif version == 'GEMV':
+                    q_linear_module = WQLinear_GEMV
+
+                q_linear = q_linear_module.from_linear(
+                    module,
+                    quant_config['w_bit'],
+                    quant_config['q_group_size'],
+                    True
+                )
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
 
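One caveat in the if/elif dispatch above: an unrecognized version string leaves q_linear_module unbound, so the next line fails with an incidental NameError. A dict-based sketch (not the committed code) that makes the mapping and the failure mode explicit:

from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV

# Same selection as the if/elif above, but unknown versions raise a
# clear ValueError instead of a NameError further down.
VERSION_TO_QLINEAR = {'GEMM': WQLinear_GEMM, 'GEMV': WQLinear_GEMV}

def select_qlinear(version):
    try:
        return VERSION_TO_QLINEAR[version]
    except KeyError:
        raise ValueError(f'Unknown AWQ kernel version: {version!r}')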