Unverified Commit f713b888 authored by Oscar Savolainen, committed by GitHub

x_max -> x_mean and w_max -> w_mean name changes and some comments (#378)

parent d9dc8e56
@@ -244,17 +244,23 @@ class AwqQuantizer:
         # Put x on the right device
         inp = inp.to(next(module2inspect.parameters()).device)

-        # [STEP 1]: Compute maximum of weight
+        # [STEP 1]: Compute per-channel mean of normalised weights
+        # All layer weights are concatted together
         weight = torch.cat([_m.weight for _m in layers], dim=0)
         org_shape = weight.shape
+        # The weights are reshaped to be organised by quantization group
         weight = weight.view(-1, self.group_size)
+        # Calculates the relative magnitude of the weights within each of the quantization groups,
+        # and rescales each group individually so that each group has weights on a 0-1 scale.
         w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        # Resizes the rescaled weight matrix back up to its original dimensions
         w_scale = w_scale.view(org_shape)
-        w_max = w_scale.mean(0)
+        # Gets the average rescaled magnitude for each output channel
+        w_mean = w_scale.mean(0)
         clear_memory(weight)

-        # [STEP 2]: Compute maximum of x
-        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
+        # [STEP 2]: Compute per-channel mean of the input activation
+        x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

         # [STEP 3]: Compute output of module
         with torch.no_grad():
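The two statistics introduced above can be reproduced standalone. Below is a minimal sketch of what STEP 1 and STEP 2 compute; the toy shapes (8 output channels, 16 input channels, `group_size=4`, 32 calibration tokens) are illustrative assumptions, not values from the commit:

```python
import torch

group_size = 4
weight = torch.randn(8, 16)   # stands in for the concatenated layer weights
inp = torch.randn(32, 16)     # stands in for the calibration activations

# [STEP 1]: per-channel mean of group-normalised weight magnitudes
org_shape = weight.shape
w_groups = weight.view(-1, group_size)  # each row is one quantization group
# rescale each group individually so its magnitudes lie on a 0-1 scale
w_scale = w_groups.abs() / w_groups.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
w_mean = w_scale.mean(0)      # average across rows; one value per weight column

# [STEP 2]: per-channel mean magnitude of the input activation
x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)  # one value per channel

print(w_mean.shape, x_mean.shape)  # both torch.Size([16])
```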
@@ -266,7 +272,7 @@ class AwqQuantizer:

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
-            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
+            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
         )

         return (
...@@ -278,8 +284,8 @@ class AwqQuantizer: ...@@ -278,8 +284,8 @@ class AwqQuantizer:
def _compute_best_scale( def _compute_best_scale(
self, self,
x, x,
w_max, w_mean,
x_max, x_mean,
module2inspect, module2inspect,
linears2scale: List[nn.Linear], linears2scale: List[nn.Linear],
fp16_output, fp16_output,
@@ -303,8 +309,8 @@ class AwqQuantizer:
         org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

         device = x.device
-        x_max = x_max.view(-1).to(device)
-        w_max = w_max.view(-1).to(device)
+        x_mean = x_mean.view(-1).to(device)
+        w_mean = w_mean.view(-1).to(device)

         for ratio in range(n_grid):
             # create new scales
@@ -312,9 +318,9 @@ class AwqQuantizer:

             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
-                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
+                scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
             else:
-                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
+                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)
...
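For context, the renamed `x_mean` / `w_mean` tensors drive the ratio grid search in `_compute_best_scale`, where each candidate scale is `s = x_mean^ratio / w_mean^(1 - ratio)`. A minimal sketch of that search loop follows; the loss evaluation against `fp16_output` is elided, and the `n_grid` value and per-channel statistics are assumed toy inputs:

```python
import torch

n_grid = 20                         # grid resolution assumed for illustration
x_mean = torch.rand(16) + 1e-3      # toy per-channel activation means
w_mean = torch.rand(16) + 1e-3      # toy per-channel weight means

for step in range(n_grid):
    ratio = step / n_grid           # ratio sweeps 0.0, 0.05, ..., 0.95
    # duo scaling balances activation statistics against weight statistics:
    # s = x_mean^ratio / w_mean^(1 - ratio)
    scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
    # normalise so the scales are centred around 1
    scales = scales / (scales.max() * scales.min()).sqrt()
    # ...in the real method: apply scales to the linears, fake-quantize,
    # run the module, and keep the scales with the lowest loss vs fp16_output
```

At `ratio = 0` the scales ignore the activations entirely; as `ratio` grows, channels with large activation magnitudes receive larger scales, which is the salient-channel protection described in the AWQ paper.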