Unverified Commit 2a3e0fa1 authored by Casper, committed by GitHub

Merge pull request #27 from boehm-e/main

Allow user to use custom calibration data for quantization
parents d76125bf b9ab9a64
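
This PR generalizes the calibration pipeline: calib_data may now be either the name of a Hugging Face dataset (with optional split and text_column arguments) or a pre-processed list of strings with one calibration sample per element. Below is a minimal sketch of the two call styles this enables; the Dolly dataset and column choices are illustrative placeholders, not part of this diff.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Style 1: name a Hugging Face dataset and point at the column holding the text
model.quantize(tokenizer, quant_config=quant_config,
               calib_data='databricks/databricks-dolly-15k',
               split='train', text_column='response')

# Style 2: pass your own list of strings, one calibration sample per element
# (in practice, supply a few hundred reasonably long samples)
custom_texts = ["AWQ scales salient weight channels before quantization."] * 256
model.quantize(tokenizer, quant_config=quant_config, calib_data=custom_texts)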
@@ -6,6 +6,7 @@ import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import List, Union
 from collections import defaultdict
 from awq.modules.act import ScaledActivation
@@ -41,13 +42,17 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=True, run_quant=True,
-                 calib_data="pileval"):
+                 calib_data: Union[str, List[str]] = "pileval", split="train",
+                 text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
 
         if run_search:
-            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                                                  auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
+            self.search_result = self._awq_search(
+                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
+                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
+                split=split, text_column=text_column
+            )
 
         if run_quant:
             self._awq_quant()
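
The new keyword arguments default to the old behaviour (calib_data="pileval", split="train", text_column="text"), so existing call sites are unaffected; quantize() simply threads split and text_column through to _awq_search(). A hedged sketch, reusing the model, tokenizer, and quant_config from the example above:

# Defaults preserve the old behaviour: calibration on pileval.
model.quantize(tokenizer, quant_config=quant_config)

# A custom dataset, split, and text column flow through to the search phase.
model.quantize(tokenizer, quant_config=quant_config,
               calib_data='databricks/databricks-dolly-15k',
               split='train', text_column='instruction')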
@@ -103,11 +108,14 @@ class BaseAWQForCausalLM(nn.Module):
             gc.collect()
 
     def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                    auto_scale=True, mse_range=True, calib_data="pileval"):
+                    auto_scale=True, mse_range=True, calib_data: Union[str, List[str]] = "pileval",
+                    split="train", text_column="text"):
         layers = self.get_model_layers(self.model)
 
         samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
+            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
+            split=split, text_column=text_column
+        )
         samples = torch.cat(samples, dim=0)
 
         inps = []
......
 import torch
 import logging
+from typing import List, Union
 from datasets import load_dataset
 
-def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
-    if data == "pileval":
-        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
-    else:
-        raise NotImplementedError
-    dataset = dataset.shuffle(seed=42)
+def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+                      tokenizer=None, n_samples=512, block_size=512,
+                      split="train", text_column="text"):
+    if isinstance(data, str):
+        if data == "pileval":
+            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+        else:
+            # any Hugging Face dataset name now works as a calibration source
+            dataset = load_dataset(data, split=split)
+        dataset = dataset.shuffle(seed=42)
+    elif isinstance(data, list):
+        # pre-processed list of strings: wrap each sample like a dataset row
+        dataset = [{text_column: text} for text in data]
+    else:
+        raise NotImplementedError(
+            "Either pass a string to a huggingface dataset or a list "
+            "that is preprocessed with one sample of text per element.")
+
     samples = []
     n_run = 0
     for data in dataset:
-        line = data["text"]
+        line = data[text_column]
         line = line.strip()
         line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
......
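
For reference, the helper can also be exercised on its own. A minimal sketch; the import path awq.utils.calib_data is assumed from the repository layout, and the sample text and counts are illustrative:

from transformers import AutoTokenizer
from awq.utils.calib_data import get_calib_dataset  # assumed module path

tokenizer = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')

# List input: each element is treated like one dataset row under text_column.
texts = ["Calibration measures activation ranges on representative text."] * 64
blocks = get_calib_dataset(data=texts, tokenizer=tokenizer,
                           n_samples=64, block_size=512)
# Expected: a list of tokenized blocks ready to be concatenated for the search.
print(len(blocks))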
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define data loading methods
def load_dolly():
    data = load_dataset('databricks/databricks-dolly-15k', split="train")

    # concatenate data
    def concatenate_data(x):
        return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}

    concatenated = data.map(concatenate_data)
    return [text for text in concatenated["text"]]

def load_wikitext():
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
    return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]

# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
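
The example defines load_dolly() as an alternative calibration source; swapping it in is a one-line change:

# Same pipeline, calibrated on Dolly instead of wikitext:
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_dolly())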