Commit 84e82744 authored by Casper's avatar Casper
Browse files

Pass split and text_column arguments to calib_data function

parent 077f39a0
import os import os
import gc import gc
import json import json
from typing import List, Union
import torch import torch
import logging import logging
import functools import functools
import torch.nn as nn import torch.nn as nn
from tqdm import tqdm from tqdm import tqdm
from typing import List, Union
from collections import defaultdict from collections import defaultdict
from awq.modules.act import ScaledActivation from awq.modules.act import ScaledActivation
...@@ -42,13 +42,17 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -42,13 +42,17 @@ class BaseAWQForCausalLM(nn.Module):
@torch.no_grad() @torch.no_grad()
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512, def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=True, run_quant=True, auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data: Union[str, List[str]]="pileval"): calib_data: Union[str, List[str]]="pileval", split="train",
text_column="text"):
self.quant_config = quant_config self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"] quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
if run_search: if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen, self.search_result = self._awq_search(
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data) tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
split=split, text_column=text_column
)
if run_quant: if run_quant:
self._awq_quant() self._awq_quant()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment