"vscode:/vscode.git/clone" did not exist on "e5a9c4afbe2fae6e7af70eb14b97aff325c01608"
Commit 356cbc92 authored by Casper's avatar Casper
Browse files

Make code cleaner

parent 9204d6e8
......@@ -206,7 +206,7 @@ class AwqQuantizer:
"""
Compute loss and select best scales
L(s) = ||Q(W \cdot s) (s^{-1} \cdot X) - W \cdot X||
L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Q: weight quantization function | pseudo_quantize_tensor(W * s)
X: inputs from calib dataset | X
W: original weights in FP16 | layer
......@@ -219,16 +219,21 @@ class AwqQuantizer:
best_error = float('inf')
org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}
device = x.device
x_max = x_max.view(-1).to(device)
w_max = w_max.view(-1).to(device)
for ratio in range(n_grid):
# create new scales
ratio = ratio * 1 / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4).view(-1)
ratio = ratio / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt()
# multiply scale and quantize
scales_view = scales.view(1, -1).to(device)
for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / (scales.view(1, -1))
fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
out = previous_layer(x, **kwargs)
if isinstance(out, tuple):
......@@ -237,22 +242,21 @@ class AwqQuantizer:
# measure loss and check if better than best
loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow
history.append(loss)
is_best = loss < best_error
if is_best:
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales
best_scales = scales.clone()
previous_layer.load_state_dict(org_sd)
if best_ratio == -1:
logging.debug(history)
raise Exception
best_scales = best_scales.view(-1)
assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach()
def init_quant(self, n_samples=128, seqlen=512):
layers = self.get_model_layers(self.model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment