Commit 1712ce21 authored by Erwan BOEHM
Browse files

allow user to use custom calibration data for quantization

parent abdc726c
import os
import gc
import json
from typing import List, Union
import torch
import functools
import torch.nn as nn
......@@ -39,7 +40,7 @@ class BaseAWQForCausalLM(nn.Module):
@torch.no_grad()
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"):
calib_data: Union[str, List[str]]="pileval"):
self.quant_config = quant_config
if run_search:
......@@ -95,7 +96,7 @@ class BaseAWQForCausalLM(nn.Module):
gc.collect()
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"):
auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval"):
layers = self.get_model_layers(self.model)
samples = get_calib_dataset(
......
from typing import List, Union
import torch
import logging
from datasets import load_dataset
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
def get_calib_dataset(data: Union[str, List[str]] = "pileval", tokenizer=None, n_samples=512, block_size=512):
if isinstance(data, str):
if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
raise NotImplementedError
elif isinstance(data, list):
dataset = [{"text": text} for text in data]
else:
raise NotImplementedError
dataset = dataset.shuffle(seed=42)
samples = []
n_run = 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment