Commit 9b2946b6 authored by Casper Hansen

Add deprecation warning

parent fe314160
@@ -2,12 +2,13 @@ import os
 import gc
 import json
 import torch
+import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
-from awq.modules.qlinear import WQLinear_GEMM
+from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
@@ -254,7 +255,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True, fuse_layers=False):
+                       safetensors=False, is_quantized=True, fuse_layers=False, version='GEMM'):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
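For context, a minimal usage sketch of the new version keyword, assuming the public AutoAWQForCausalLM wrapper forwards its arguments to this method; the checkpoint path and weights filename below are placeholders:

import torch
from awq import AutoAWQForCausalLM

# Placeholders for a real AWQ checkpoint directory and weights file.
model = AutoAWQForCausalLM.from_quantized(
    "path/to/quantized-model",   # placeholder model_path
    "awq_model_w4_g128.pt",      # placeholder model_filename
    version="GEMV",              # new in this commit; defaults to 'GEMM'
)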
@@ -276,7 +277,7 @@ class BaseAWQForCausalLM(nn.Module):
                 quant_config = json.loads(file.read())
         else:
             # Default config that works for most models
-            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
+            quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
 
         # Load model config and set max generation length
         if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
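As a standalone illustration of the fallback above, a sketch assuming the config file is named quant_config.json inside the checkpoint directory (the filename is an assumption from context):

import os
import json

def load_quant_config(model_path):
    # Read quant_config.json when present; otherwise fall back to the
    # defaults above, which now record the legacy 'GEMM' format explicitly.
    config_path = os.path.join(model_path, "quant_config.json")  # assumed filename
    if os.path.isfile(config_path):
        with open(config_path) as file:
            return json.loads(file.read())
    return {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}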
@@ -294,7 +295,7 @@ class BaseAWQForCausalLM(nn.Module):
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, quant_config)
+            self._load_quantized_modules(self, model, quant_config, version)
             model.tie_weights()
@@ -334,9 +335,14 @@ class BaseAWQForCausalLM(nn.Module):
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)
 
-    def _load_quantized_modules(self, model, quant_config):
+    def _load_quantized_modules(self, model, quant_config, version):
         # Real quantization of weights
         assert quant_config["zero_point"], "We only support zero_point quantization now."
+        if version == 'GEMM':
+            logging.warning('Deprecated model weight format. Re-quantize '
+                            'your weights again with version="GEMV" for a speedup. '
+                            'In the next AutoAWQ version, GEMM will be deprecated.')
 
         # Get blocks of model
         layers = self.get_model_layers(model)
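Since the notice goes through logging.warning on the root logger, downstream code can raise the logging threshold to suppress it once acknowledged; a minimal sketch using only the standard library:

import logging

# The root logger prints WARNING and above to stderr by default.
# Raising the threshold hides the GEMM deprecation notice.
logging.getLogger().setLevel(logging.ERROR)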
@@ -352,8 +358,17 @@ class BaseAWQForCausalLM(nn.Module):
 
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
-                q_linear = WQLinear_GEMM.from_linear(
-                    module, quant_config['w_bit'], quant_config['q_group_size'], True)
+                if version == 'GEMM':
+                    q_linear_module = WQLinear_GEMM
+                elif version == 'GEMV':
+                    q_linear_module = WQLinear_GEMV
+
+                q_linear = q_linear_module.from_linear(
+                    module,
+                    quant_config['w_bit'],
+                    quant_config['q_group_size'],
+                    True
+                )
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
 
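Taken in isolation, the replacement step works like the sketch below. The trailing True is presumably an init_only-style flag that allocates an empty quantized module to be filled from the checkpoint rather than quantizing on the spot (the parameter name is an assumption). Note also that an unrecognized version would leave q_linear_module unbound; the code relies on callers passing only 'GEMM' or 'GEMV'.

import torch.nn as nn
from awq.modules.qlinear import WQLinear_GEMM, WQLinear_GEMV

version = "GEMV"
quant_config = {"w_bit": 4, "q_group_size": 128}
module = nn.Linear(4096, 4096)  # stand-in for one projection layer

# Same dispatch as the diff: pick the kernel-specific class per version.
q_linear_module = WQLinear_GEMM if version == "GEMM" else WQLinear_GEMV

q_linear = q_linear_module.from_linear(
    module,
    quant_config["w_bit"],
    quant_config["q_group_size"],
    True,  # presumably init_only: allocate buffers, skip real quantization
)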
...