Commit 356cbc92 authored by Casper's avatar Casper
Browse files

Make code cleaner

parent 9204d6e8
...@@ -206,7 +206,7 @@ class AwqQuantizer: ...@@ -206,7 +206,7 @@ class AwqQuantizer:
""" """
Compute loss and select best scales Compute loss and select best scales
L(s) = ||Q(W \cdot s) (s^{-1} \cdot X) - W \cdot X|| L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Q: weight quantization function | pseudo_quantize_tensor(W * s) Q: weight quantization function | pseudo_quantize_tensor(W * s)
X: inputs from calib dataset | X X: inputs from calib dataset | X
W: original weights in FP16 | layer W: original weights in FP16 | layer
...@@ -219,16 +219,21 @@ class AwqQuantizer: ...@@ -219,16 +219,21 @@ class AwqQuantizer:
best_error = float('inf') best_error = float('inf')
org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()} org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}
device = x.device
x_max = x_max.view(-1).to(device)
w_max = w_max.view(-1).to(device)
for ratio in range(n_grid): for ratio in range(n_grid):
# create new scales # create new scales
ratio = ratio * 1 / n_grid ratio = ratio / n_grid
scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4).view(-1) scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
scales = scales / (scales.max() * scales.min()).sqrt() scales = scales / (scales.max() * scales.min()).sqrt()
# multiply scale and quantize scales_view = scales.view(1, -1).to(device)
for fc in linears2scale: for fc in linears2scale:
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device)) fc.weight.mul_(scales_view)
fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / (scales.view(1, -1)) fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
out = previous_layer(x, **kwargs) out = previous_layer(x, **kwargs)
if isinstance(out, tuple): if isinstance(out, tuple):
...@@ -237,22 +242,21 @@ class AwqQuantizer: ...@@ -237,22 +242,21 @@ class AwqQuantizer:
# measure loss and check if better than best # measure loss and check if better than best
loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow
history.append(loss) history.append(loss)
is_best = loss < best_error if loss < best_error:
if is_best:
best_error = loss best_error = loss
best_ratio = ratio best_ratio = ratio
best_scales = scales best_scales = scales.clone()
previous_layer.load_state_dict(org_sd) previous_layer.load_state_dict(org_sd)
if best_ratio == -1: if best_ratio == -1:
logging.debug(history) logging.debug(history)
raise Exception raise Exception
best_scales = best_scales.view(-1)
assert torch.isnan(best_scales).sum() == 0, best_scales assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach() return best_scales.detach()
def init_quant(self, n_samples=128, seqlen=512): def init_quant(self, n_samples=128, seqlen=512):
layers = self.get_model_layers(self.model) layers = self.get_model_layers(self.model)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment