"git@developer.sourcefind.cn:OpenDAS/vision.git" did not exist on "8faaefc41af0b4fdbe0543d455390acc4f4c7710"
Commit a9cef34b authored by Casper's avatar Casper
Browse files

Actually pass args to calib function

parent 69d31edc
...@@ -108,11 +108,14 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -108,11 +108,14 @@ class BaseAWQForCausalLM(nn.Module):
gc.collect() gc.collect()
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512, def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval"): auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval",
split="train", text_column="text"):
layers = self.get_model_layers(self.model) layers = self.get_model_layers(self.model)
samples = get_calib_dataset( samples = get_calib_dataset(
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen) data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
split=split, text_column=text_column
)
samples = torch.cat(samples, dim=0) samples = torch.cat(samples, dim=0)
inps = [] inps = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment