Unverified Commit 2350a4d0 authored by Younes Belkada, committed by GitHub

Fix quantization issue with transformers >= 4.36.0 (#264)

parent 9c3dfa07
 import torch
+import inspect
 import logging
 import functools
 import torch.nn as nn
@@ -170,14 +171,16 @@ class AwqQuantizer:
         # [STEP 3]: Compute output of module
         with torch.no_grad():
-            fp16_output = module2inspect(inp, **kwargs)
+            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
+            fp16_output = module2inspect(inp, **module_kwargs)
             if isinstance(fp16_output, tuple):
                 fp16_output = fp16_output[0]
 
         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
             inp, w_max, x_max, module2inspect,
-            layers, fp16_output, kwargs
+            layers, fp16_output, module_kwargs
         )
 
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
@@ -390,10 +393,36 @@ class AwqQuantizer:
                 feat_dict=input_feat)))
         self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
         # get output as next layer's input
-        self.inps = layer(self.inps, **self.module_kwargs)[0]
+
+        # Sanitize the kwargs in case we use a transformers version that contains
+        # kwargs that are not handled by the module.
+        # Useful for trust_remote_code models.
+        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
+
+        self.inps = layer(self.inps, **module_kwargs)[0]
         for h in handles:
             h.remove()
 
         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
 
         return input_feat
+
+    def _sanitize_kwargs(self, inputs_kwargs, module):
+        """
+        Remove the arguments that are not supported in the module's
+        forward pass to avoid breaking behaviour between different versions
+        of transformers.
+
+        Args:
+            inputs_kwargs (`dict`):
+                The input dictionary to pass to the model layer
+            module (`torch.nn.Module`):
+                Target module to quantize.
+        """
+        module_signature = inspect.signature(module.forward).parameters
+        sanitized_kwargs = {}
+        for k, v in inputs_kwargs.items():
+            if k in module_signature:
+                sanitized_kwargs[k] = v
+        return sanitized_kwargs
\ No newline at end of file
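
For context, here is a minimal standalone sketch of the sanitization pattern this commit introduces. `DummyLayer` and its kwargs are hypothetical stand-ins for a transformer decoder layer; only the filtering logic mirrors the diff:

```python
import inspect

import torch
import torch.nn as nn


class DummyLayer(nn.Module):
    """Stand-in for a decoder layer whose forward pass does not
    accept every kwarg that the calling code may have collected."""

    def forward(self, hidden_states, attention_mask=None):
        return hidden_states


def sanitize_kwargs(inputs_kwargs, module):
    # Keep only the kwargs named in the module's forward signature,
    # mirroring AwqQuantizer._sanitize_kwargs above.
    module_signature = inspect.signature(module.forward).parameters
    return {k: v for k, v in inputs_kwargs.items() if k in module_signature}


layer = DummyLayer()
kwargs = {
    "attention_mask": torch.ones(1, 4),
    "position_ids": torch.arange(4),  # not in DummyLayer.forward, so dropped
}
print(sanitize_kwargs(kwargs, layer).keys())  # dict_keys(['attention_mask'])
```

Without the filtering step, `layer(self.inps, **kwargs)` would raise a `TypeError` for any unexpected keyword argument, which is the failure mode this commit guards against when the installed transformers version passes kwargs a given layer's forward does not accept.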