Unverified Commit 2a3e0fa1 authored by Casper, committed by GitHub

Merge pull request #27 from boehm-e/main

Allow user to use custom calibration data for quantization
parents d76125bf b9ab9a64
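
The change adds three arguments to quantize() and get_calib_dataset(): calib_data may now be either the name of a Hugging Face dataset or a list of pre-processed text samples, while split and text_column select which split and column are read. A minimal sketch of the dataset-name path follows; the dataset id timdettmers/openassistant-guanaco is only an illustration and not part of this PR (the list path is shown in the example file at the end of the diff):

# Hedged sketch: quantize against a named Hugging Face dataset instead of the
# default "pileval". The dataset id and column name here are assumptions.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model = AutoAWQForCausalLM.from_pretrained('lmsys/vicuna-7b-v1.5')
tokenizer = AutoTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5', trust_remote_code=True)

model.quantize(
    tokenizer,
    quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"},
    calib_data="timdettmers/openassistant-guanaco",  # any dataset id on the Hub
    split="train",        # forwarded to load_dataset(...)
    text_column="text",   # column read for each calibration sample
)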
@@ -6,6 +6,7 @@
 import logging
 import functools
 import torch.nn as nn
 from tqdm import tqdm
+from typing import List, Union
 from collections import defaultdict
 from awq.modules.act import ScaledActivation
@@ -41,13 +42,17 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                        auto_scale=True, mse_range=True, run_search=True, run_quant=True,
-                       calib_data="pileval"):
+                       calib_data: Union[str, List[str]]="pileval", split="train",
+                       text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
 
         if run_search:
-            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
+            self.search_result = self._awq_search(
+                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
+                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
+                split=split, text_column=text_column
+            )
 
         if run_quant:
             self._awq_quant()
@@ -103,11 +108,14 @@ class BaseAWQForCausalLM(nn.Module):
         gc.collect()
 
     def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                          auto_scale=True, mse_range=True, calib_data="pileval"):
+                          auto_scale=True, mse_range=True, calib_data: Union[str, List[str]]="pileval",
+                          split="train", text_column="text"):
         layers = self.get_model_layers(self.model)
 
         samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
+            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
+            split=split, text_column=text_column
+        )
         samples = torch.cat(samples, dim=0)
 
         inps = []
...
 import torch
 import logging
+from typing import List, Union
 from datasets import load_dataset
 
-def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
-    if data == "pileval":
-        dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+                      tokenizer=None, n_samples=512, block_size=512,
+                      split="train", text_column="text"):
+    if isinstance(data, str):
+        if data == "pileval":
+            dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
+        else:
+            dataset = load_dataset(data, split=split)
+
+        dataset = dataset.shuffle(seed=42)
+    elif isinstance(data, list):
+        dataset = [{text_column: text} for text in data]
     else:
-        raise NotImplementedError
-    dataset = dataset.shuffle(seed=42)
+        raise NotImplementedError(
+            "Either pass a string to a huggingface dataset or a list "
+            "that is preprocessed with one sample of text per element.")
 
     samples = []
     n_run = 0
     for data in dataset:
-        line = data["text"]
+        line = data[text_column]
         line = line.strip()
         line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
...
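
When calib_data is a list, the isinstance(data, list) branch above wraps each string as a one-field record keyed by text_column and never calls load_dataset. A minimal sketch of calling the loader directly, assuming tokenizer is already loaded and my_texts holds a few hundred calibration strings:

# my_texts: list[str], one calibration sample per element (assumed prepared elsewhere)
samples = get_calib_dataset(
    data=my_texts,
    tokenizer=tokenizer,
    n_samples=128,        # upper bound on the number of text samples used
    block_size=512,       # length of each packed token block
    text_column="text",   # key used when wrapping each raw string
)
# base.py then concatenates the returned blocks with torch.cat(samples, dim=0)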
from datasets import load_dataset
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq'
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }

# Load model
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Define data loading methods
def load_dolly():
    data = load_dataset('databricks/databricks-dolly-15k', split="train")

    # concatenate data
    def concatenate_data(x):
        return {"text": x['instruction'] + '\n' + x['context'] + '\n' + x['response']}

    concatenated = data.map(concatenate_data)
    return [text for text in concatenated["text"]]

def load_wikitext():
    data = load_dataset('wikitext', 'wikitext-2-raw-v1', split="train")
    return [text for text in data["text"] if text.strip() != '' and len(text.split(' ')) > 20]

# Quantize
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_wikitext())

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
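
The same script should work with the Dolly loader defined above; only the calib_data argument changes:

# instruction-style calibration data instead of wikitext
model.quantize(tokenizer, quant_config=quant_config, calib_data=load_dolly())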