"tests/test_config.py" did not exist on "a2090375ca8e54e7494d73ba31ce91c50659e556"
Commit 90bf52c7 authored by Casper Hansen's avatar Casper Hansen
Browse files

Remove cosine loss

parent 9848b6a4
......@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
@torch.no_grad()
def quantize(self, tokenizer=None, quant_config={},
calib_data: Union[str, List[str]]="pileval",
split="train", text_column="text", loss_objective='mse'):
split="train", text_column="text"):
self.quant_config = quant_config
quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
quantizer = AwqQuantizer(
self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
quant_config["version"], calib_data, split, text_column, loss_objective
quant_config["version"], calib_data, split, text_column
)
quantizer.quantize()
self.is_quantized = True
......
......@@ -13,7 +13,7 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
class AwqQuantizer:
def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
calib_data, split, text_column, loss_objective='mse') -> None:
calib_data, split, text_column) -> None:
self.awq_model = awq_model
self.model = model
self.tokenizer = tokenizer
......@@ -24,7 +24,6 @@ class AwqQuantizer:
self.split = split
self.text_column = text_column
self.modules, self.module_kwargs, self.inps = self.init_quant()
self.loss_objective = loss_objective
def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
org_w_shape = w.shape
......@@ -191,12 +190,9 @@ class AwqQuantizer:
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
if self.loss_objective == 'mse': # (L2 norm)
# compute mean squared error (L2 norm)
loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
elif self.loss_objective == 'cosine':
loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()
history.append(loss)
if loss < best_error:
best_error = loss
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment