Commit c9d01ac3 authored by Casper Hansen

Better comments. Implement cosine similarity

parent e205548d
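The diff below threads a new `loss_objective` argument from `quantize()` down into `AwqQuantizer`, then uses it in `_compute_best_scale` to choose between mean squared error and negative cosine similarity when scoring candidate scales. As orientation, here is a minimal, self-contained sketch (not part of the diff; tensor shapes are illustrative assumptions) of how the two objectives compare. Note the sketch passes `dim=-1` explicitly so the similarity runs over per-token hidden vectors; the diff itself relies on `cosine_similarity`'s default dim:

```python
import torch
import torch.nn as nn

# Illustrative activation shapes: [batch, seq_len, hidden_dim]
fp16_output = torch.randn(1, 64, 512)
int_w_output = fp16_output + 0.01 * torch.randn_like(fp16_output)

# 'mse': mean squared error; lower is better
mse_loss = (fp16_output - int_w_output).float().pow(2).mean().item()

# 'cosine': negated mean cosine similarity, so lower is also better;
# dim=-1 (an assumption of this sketch) compares per-token hidden vectors
cosine_loss = -nn.functional.cosine_similarity(
    fp16_output.float(), int_w_output.float(), dim=-1
).mean().item()

print(f"mse={mse_loss:.6f} cosine={cosine_loss:.6f}")
```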
@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text"):
+                       split="train", text_column="text", loss_objective='mse'):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

         quantizer = AwqQuantizer(
             self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
-            quant_config["version"], calib_data, split, text_column
+            quant_config["version"], calib_data, split, text_column, loss_objective
         )
         quantizer.quantize()
         self.is_quantized = True
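From the caller's side the new keyword is opt-in and defaults to 'mse'. A hedged usage sketch, assuming the usual AutoAWQ entry point and a hypothetical model path:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "some-org/some-model"  # hypothetical path for illustration

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

quant_config = {"w_bit": 4, "q_group_size": 128, "version": "GEMM"}

# loss_objective is the keyword added in this commit; 'cosine' selects the
# negative-cosine-similarity objective during scale search
model.quantize(tokenizer, quant_config=quant_config, loss_objective="cosine")
```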
@@ -12,7 +12,8 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,

 class AwqQuantizer:
-    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
+                       calib_data, split, text_column, loss_objective='mse') -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -23,6 +24,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.modules, self.module_kwargs, self.inps = self.init_quant()
+        self.loss_objective = loss_objective

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
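`pseudo_quantize_tensor` (context above) is the quantize-then-dequantize round trip used to simulate INT weights during the search. A hedged sketch of group-wise asymmetric quantization in that style; the real method also supports returning scales and zero points via `get_scale_zp`, so treat this as an illustration rather than the implementation:

```python
import torch

def pseudo_quantize_sketch(w: torch.Tensor, w_bit: int = 4, group_size: int = 128):
    # assumes w.numel() is divisible by group_size
    org_shape = w.shape
    w = w.reshape(-1, group_size)               # one scale/zero point per group
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** w_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
    # snap to the integer grid, then dequantize back to float
    w = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
    return w.reshape(org_shape)
```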
@@ -74,6 +76,10 @@ class AwqQuantizer:
             clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
+            self._apply_quant(self.modules[i], named_linears)
+            clear_memory()
+
+    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
         for name, linear_layer in named_linears.items():
             # NOTE: small regression in perplexity if linear layer uses .cpu().float()
             linear_layer = linear_layer.cuda().half()
@@ -101,10 +107,8 @@ class AwqQuantizer:
             )
             linear_layer.cpu()
-            q_linear.to(next(self.modules[i].parameters()).device)
-            set_op_by_name(self.modules[i], name, q_linear)
-            clear_memory()
+            q_linear.to(next(module.parameters()).device)
+            set_op_by_name(module, name, q_linear)
+            clear_memory()

     @torch.no_grad()
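The refactor above lets `_apply_quant` operate on whatever `module` it is handed instead of indexing `self.modules[i]`: the quantized `q_linear` is moved to the parent module's device and swapped in with `set_op_by_name`. A hedged sketch of that swap-by-dotted-path idea; the real helper lives in awq.utils.module and may handle more cases:

```python
import torch.nn as nn

def set_op_by_name_sketch(root: nn.Module, name: str, new_module: nn.Module):
    # walk a "layers.0.mlp.down_proj"-style path to the parent, then replace the leaf
    *parents, leaf = name.split(".")
    parent = root
    for part in parents:
        parent = getattr(parent, part)  # string digits also resolve into ModuleList
    setattr(parent, leaf, new_module)
```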
@@ -133,19 +137,20 @@ class AwqQuantizer:
         # [STEP 3]: Compute output of module
         with torch.no_grad():
-            org_out = module2inspect(inp, **kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
+            fp16_output = module2inspect(inp, **kwargs)
+            if isinstance(fp16_output, tuple):
+                fp16_output = fp16_output[0]

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
             inp, w_max, x_max, module2inspect,
-            layers, org_out, kwargs
+            layers, fp16_output, kwargs
         )

         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear], org_out, kwargs={}):
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
+                       fp16_output, kwargs={}):
         """
         Compute loss and select best scales
@@ -170,20 +175,29 @@ class AwqQuantizer:
         for ratio in range(n_grid):
             # create new scales
             ratio = ratio / n_grid

+            # s^-1
             scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)

+            # NOTE: s^-1 * x is fused here, according to paper
             for fc in linears2scale:
+                # Q(W * s)
                 fc.weight.mul_(scales_view)
                 fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

-            out = module2inspect(x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
+            # W * X
+            int_w_output = module2inspect(x, **kwargs)
+            if isinstance(int_w_output, tuple):
+                int_w_output = int_w_output[0]

-            # measure loss and check if better than best
-            loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow
+            if self.loss_objective == 'mse': # (L2 norm)
+                loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
+            elif self.loss_objective == 'cosine':
+                loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()

             history.append(loss)
             if loss < best_error:
                 best_error = loss
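Tying the hunk above together: per the AWQ formulation, each candidate ratio α gives per-channel scales s = x_max^α / w_max^(1-α); the weights are quantized as Q(W·s) with s⁻¹ folded back into the weights (equivalently applied to x), and the α whose quantized output best matches the FP16 output wins. A hedged sketch of that grid search, with `evaluate_loss` standing in for the quantize-forward-score step the diff performs inline:

```python
import torch

def grid_search_sketch(x_max: torch.Tensor, w_max: torch.Tensor,
                       evaluate_loss, n_grid: int = 20):
    best_error, best_scales = float("inf"), None
    for i in range(n_grid):
        ratio = i / n_grid
        # s = x_max**ratio / w_max**(1 - ratio): balance activation vs. weight magnitudes
        scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()  # normalize dynamic range
        loss = evaluate_loss(scales)  # e.g. MSE or negative cosine vs. FP16 output
        if loss < best_error:
            best_error, best_scales = loss, scales.clone()
    return best_scales
```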