Commit 1712ce21 authored by Erwan BOEHM
Browse files

Allow users to supply custom calibration data for quantization

parent abdc726c
import os import os
import gc import gc
import json import json
from typing import List, Union
import torch import torch
import functools import functools
import torch.nn as nn import torch.nn as nn
...@@ -39,7 +40,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -39,7 +40,7 @@ class BaseAWQForCausalLM(nn.Module):
@torch.no_grad() @torch.no_grad()
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512, def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=True, run_quant=True, auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data="pileval"): calib_data: Union[str, List[str]]="pileval"):
self.quant_config = quant_config self.quant_config = quant_config
if run_search: if run_search:
...@@ -95,7 +96,7 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -95,7 +96,7 @@ class BaseAWQForCausalLM(nn.Module):
gc.collect() gc.collect()
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512, def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data="pileval"): auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval"):
layers = self.get_model_layers(self.model) layers = self.get_model_layers(self.model)
samples = get_calib_dataset( samples = get_calib_dataset(
......
from typing import List, Union
import torch import torch
import logging import logging
from datasets import load_dataset from datasets import load_dataset
def get_calib_dataset(data="pileval", tokenizer=None, n_samples=512, block_size=512): def get_calib_dataset(data: Union[str, List[str]] = "pileval", tokenizer=None, n_samples=512, block_size=512):
if data == "pileval": if isinstance(data, str):
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
raise NotImplementedError
elif isinstance(data, list):
dataset = [{"text": text} for text in data]
else: else:
raise NotImplementedError raise NotImplementedError
dataset = dataset.shuffle(seed=42) dataset = dataset.shuffle(seed=42)
samples = [] samples = []
n_run = 0 n_run = 0
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment