Commit a9cef34b authored by Casper's avatar Casper
Browse files

Actually pass args to calib function

parent 69d31edc
......@@ -108,11 +108,14 @@ class BaseAWQForCausalLM(nn.Module):
gc.collect()
def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval"):
auto_scale=True, mse_range=True, calib_data:Union[str, List[str]]="pileval",
split="train", text_column="text"):
layers = self.get_model_layers(self.model)
samples = get_calib_dataset(
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen,
split=split, text_column=text_column
)
samples = torch.cat(samples, dim=0)
inps = []
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment