Commit c9d01ac3 authored by Casper Hansen

Better comments. Implement cosine similarity

parent e205548d
@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text"):
+                       split="train", text_column="text", loss_objective='mse'):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

         quantizer = AwqQuantizer(
             self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
-            quant_config["version"], calib_data, split, text_column
+            quant_config["version"], calib_data, split, text_column, loss_objective
         )
         quantizer.quantize()
         self.is_quantized = True
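For reference, a minimal usage sketch of the new loss_objective parameter. It assumes the AutoAWQForCausalLM entry point from this repo; the model path and quant_config values are illustrative, not prescribed by this commit:

    from awq import AutoAWQForCausalLM            # assumed entry point
    from transformers import AutoTokenizer

    model_path = "path/to/model"                  # illustrative
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # loss_objective defaults to 'mse'; 'cosine' selects the new objective
    model.quantize(
        tokenizer,
        quant_config={"w_bit": 4, "q_group_size": 128, "version": "GEMM"},
        loss_objective="cosine",
    )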
@@ -12,7 +12,8 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 class AwqQuantizer:
-    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
+                 calib_data, split, text_column, loss_objective='mse') -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
@@ -23,6 +24,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.modules, self.module_kwargs, self.inps = self.init_quant()
+        self.loss_objective = loss_objective

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
@@ -74,37 +76,39 @@ class AwqQuantizer:
             clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
-            for name, linear_layer in named_linears.items():
-                # NOTE: small regression in perplexity if linear layer uses .cpu().float()
-                linear_layer = linear_layer.cuda().half()
-
-                linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
-                    linear_layer.weight.data,
-                    get_scale_zp=True
-                )
-
-                if self.version == 'GEMM':
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear_module = WQLinear_GEMM
-
-                elif self.version == 'GEMV':
-                    q_linear_module = WQLinear_GEMV
-
-                q_linear = q_linear_module.from_linear(
-                    linear=linear_layer,
-                    w_bit=self.w_bit,
-                    group_size=self.group_size,
-                    init_only=False,
-                    scales=scales,
-                    zeros=zeros
-                )
-
-                linear_layer.cpu()
-                q_linear.to(next(self.modules[i].parameters()).device)
-                set_op_by_name(self.modules[i], name, q_linear)
-                clear_memory()
+            self._apply_quant(self.modules[i], named_linears)
+            clear_memory()
+
+    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
+        for name, linear_layer in named_linears.items():
+            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
+            linear_layer = linear_layer.cuda().half()
+
+            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
+                linear_layer.weight.data,
+                get_scale_zp=True
+            )
+
+            if self.version == 'GEMM':
+                scales = scales.t().contiguous()
+                zeros = zeros.t().contiguous()
+                q_linear_module = WQLinear_GEMM
+
+            elif self.version == 'GEMV':
+                q_linear_module = WQLinear_GEMV
+
+            q_linear = q_linear_module.from_linear(
+                linear=linear_layer,
+                w_bit=self.w_bit,
+                group_size=self.group_size,
+                init_only=False,
+                scales=scales,
+                zeros=zeros
+            )
+
+            linear_layer.cpu()
+            q_linear.to(next(module.parameters()).device)
+            set_op_by_name(module, name, q_linear)
             clear_memory()

     @torch.no_grad()
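For context, pseudo_quantize_tensor (whose body is largely elided in this diff) performs a group-wise round-to-nearest quantize/dequantize and can also return the per-group scales and zero points used above. The sketch below is an illustrative stand-in under those assumptions, not the repo's exact implementation:

    import torch

    def pseudo_quantize_rtn(w: torch.Tensor, w_bit: int = 4, group_size: int = 128):
        # assumes the number of weight elements is divisible by group_size
        org_shape = w.shape
        w = w.reshape(-1, group_size)

        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1

        # asymmetric per-group scale and zero point
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp(0, max_int)

        # quantize, then immediately dequantize ("pseudo" quantization)
        w_q = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
        return w_q.reshape(org_shape), scales, zeros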
@@ -133,19 +137,20 @@ class AwqQuantizer:

         # [STEP 3]: Compute output of module
         with torch.no_grad():
-            org_out = module2inspect(inp, **kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
+            fp16_output = module2inspect(inp, **kwargs)
+            if isinstance(fp16_output, tuple):
+                fp16_output = fp16_output[0]

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
             inp, w_max, x_max, module2inspect,
-            layers, org_out, kwargs
+            layers, fp16_output, kwargs
         )

         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear], org_out, kwargs={}):
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear],
+                            fp16_output, kwargs={}):
         """
         Compute loss and select best scales
@@ -170,20 +175,29 @@ class AwqQuantizer:
         for ratio in range(n_grid):
             # create new scales
             ratio = ratio / n_grid

+            # s^-1
             scales = (x_max.pow(ratio) / w_max.pow(1-ratio)).clamp(min=1e-4)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)

+            # NOTE: s^-1 * x is fused here, according to paper
             for fc in linears2scale:
+                # Q(W * s)
                 fc.weight.mul_(scales_view)
                 fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

-            out = module2inspect(x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
+            # W * X
+            int_w_output = module2inspect(x, **kwargs)
+            if isinstance(int_w_output, tuple):
+                int_w_output = int_w_output[0]

-            # measure loss and check if better than best
-            loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow
+            if self.loss_objective == 'mse': # (L2 norm)
+                loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
+            elif self.loss_objective == 'cosine':
+                loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()

             history.append(loss)
             if loss < best_error:
                 best_error = loss
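As a standalone illustration of the two objectives on dummy tensors (not tied to the quantizer state; cosine_similarity uses its default dim=1 here, as in the call above):

    import torch
    import torch.nn as nn

    fp16_output = torch.randn(4, 512, 768)
    int_w_output = fp16_output + 0.01 * torch.randn_like(fp16_output)

    # 'mse': mean squared error, computed in float32 to avoid fp16 overflow
    mse_loss = (fp16_output - int_w_output).float().pow(2).mean().item()

    # 'cosine': negative mean cosine similarity, so lower is still better and
    # the same "loss < best_error" comparison works in the grid search
    cosine_loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()

    print(mse_loss, cosine_loss)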