".github/vscode:/vscode.git/clone" did not exist on "84bbb7140e03df01b3bb388ba4df299328ea2dff"
Commit 84e82744 authored by Casper's avatar Casper
Browse files

Pass split and text_column arguments to calib_data function

parent 077f39a0
import os
import gc
import json
from typing import List, Union
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from collections import defaultdict
from awq.modules.act import ScaledActivation
......@@ -42,13 +42,17 @@ class BaseAWQForCausalLM(nn.Module):
@torch.no_grad()
def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, run_search=True, run_quant=True,
calib_data: Union[str, List[str]]="pileval"):
calib_data: Union[str, List[str]]="pileval", split="train",
text_column="text"):
self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
if run_search:
self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
self.search_result = self._awq_search(
tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
split=split, text_column=text_column
)
if run_quant:
self._awq_quant()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment