Commit 90bf52c7 authored by Casper Hansen

Remove cosine loss

parent 9848b6a4
@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                  calib_data: Union[str, List[str]]="pileval",
-                 split="train", text_column="text", loss_objective='mse'):
+                 split="train", text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
         quantizer = AwqQuantizer(
             self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
-            quant_config["version"], calib_data, split, text_column, loss_objective
+            quant_config["version"], calib_data, split, text_column
         )
         quantizer.quantize()
         self.is_quantized = True
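For context, a hypothetical call site after this change. The `AutoAWQForCausalLM` and `AutoTokenizer` entry points and the placeholder model path are assumptions, not part of this diff; the point is simply that `loss_objective` is gone from the public `quantize()` signature.

```python
# Hypothetical usage sketch after this commit; entry points assumed, not from the diff.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "facebook/opt-125m"  # placeholder model, assumption for illustration
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

quant_config = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}
model.quantize(tokenizer, quant_config=quant_config)
# Passing loss_objective="cosine" here would now raise a TypeError.
```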
@@ -13,7 +13,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 class AwqQuantizer:
     def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
-                 calib_data, split, text_column, loss_objective='mse') -> None:
+                 calib_data, split, text_column) -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -24,7 +24,6 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.modules, self.module_kwargs, self.inps = self.init_quant()
-        self.loss_objective = loss_objective

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
@@ -191,11 +190,8 @@ class AwqQuantizer:
             if isinstance(int_w_output, tuple):
                 int_w_output = int_w_output[0]

-            if self.loss_objective == 'mse': # (L2 norm)
-                loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
-            elif self.loss_objective == 'cosine':
-                loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()
+            # compute mean squared error (L2 norm)
+            loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow

             history.append(loss)
             if loss < best_error:
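A minimal standalone sketch of the two objectives this last hunk touches: the retained MSE loss and the removed negative cosine similarity. The toy linear layer and the crude weight-rounding stand-in for `pseudo_quantize_tensor()` are assumptions for illustration; only the two loss lines mirror the code above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy stand-in for one transformer sublayer (assumption, not the repo's module).
module = nn.Linear(64, 64)
x = torch.randn(8, 64)

fp16_output = module(x)  # in AWQ this is the original fp16 forward pass

with torch.no_grad():
    # Crude stand-in for pseudo_quantize_tensor(): snap weights to a
    # symmetric 4-bit grid per output channel (the real code uses
    # asymmetric group-wise quantization).
    w = module.weight.data
    scale = w.abs().amax(dim=1, keepdim=True) / 7
    module.weight.data = (w / scale).round().clamp(-8, 7) * scale
    int_w_output = module(x)

# Retained objective: MSE. The .float() cast guards against fp16 overflow
# when squaring, as the NOTE in the diff says.
mse_loss = (fp16_output - int_w_output).float().pow(2).mean().item()

# Removed objective, shown for reference: negative mean cosine similarity,
# which is scale-invariant and so ignores magnitude errors that MSE penalizes.
cos_loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()

print(f"MSE loss: {mse_loss:.6f}, cosine loss: {cos_loss:.6f}")
```

The scale-invariance of cosine similarity is one plausible reason to prefer MSE here: scale search is precisely about matching output magnitudes, which cosine similarity cannot see.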