Unverified Commit 2350a4d0 authored by Younes Belkada, committed by GitHub

Fix quantization issue with transformers >= 4.36.0 (#264)

parent 9c3dfa07
 import torch
+import inspect
 import logging
 import functools
 import torch.nn as nn
@@ -170,14 +171,16 @@ class AwqQuantizer:
         # [STEP 3]: Compute output of module
         with torch.no_grad():
-            fp16_output = module2inspect(inp, **kwargs)
+            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
+            fp16_output = module2inspect(inp, **module_kwargs)
 
             if isinstance(fp16_output, tuple):
                 fp16_output = fp16_output[0]
 
         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
             inp, w_max, x_max, module2inspect,
-            layers, fp16_output, kwargs
+            layers, fp16_output, module_kwargs
         )
 
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
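
For context, the sketch below illustrates the failure mode this hunk guards against: newer transformers releases can hand extra keyword arguments to a layer whose forward signature does not declare them, which raises a TypeError. The toy module and variable names here are illustrative assumptions, not code from this repository.

import inspect
import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    # Hypothetical layer (not part of this patch) whose forward only
    # accepts `attention_mask` besides the hidden states.
    def forward(self, hidden_states, attention_mask=None):
        return hidden_states

block = ToyBlock()
x = torch.randn(1, 4, 8)
kwargs = {"attention_mask": None, "position_ids": torch.arange(4)}

# Passing the raw kwargs fails on layers that do not declare every key:
#   block(x, **kwargs)  ->  TypeError: ... unexpected keyword argument 'position_ids'

# Filtering the kwargs against the forward signature, as the patch does, avoids this:
allowed = inspect.signature(block.forward).parameters
safe_kwargs = {k: v for k, v in kwargs.items() if k in allowed}
out = block(x, **safe_kwargs)  # only `attention_mask` is forwarded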
@@ -390,10 +393,36 @@ class AwqQuantizer:
                                   feat_dict=input_feat)))
         self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
 
         # get output as next layer's input
-        self.inps = layer(self.inps, **self.module_kwargs)[0]
+
+        # Sanitize the kwargs in case we use transformers version that contains
+        # kwargs that are not handled by the module.
+        # Useful for trust_remote_code models.
+        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
+
+        self.inps = layer(self.inps, **module_kwargs)[0]
 
         for h in handles:
             h.remove()
 
         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
 
         return input_feat
+
+    def _sanitize_kwargs(self, inputs_kwargs, module):
+        """
+        Remove the arguments that are not supported in the module's
+        forward pass to avoid breaking behaviour between different versions
+        of transformers.
+
+        Args:
+            inputs_kwargs (`dict`):
+                The input dictionary to pass to the model layer
+            module (`torch.nn.Module`):
+                Target module to quantize.
+        """
+        module_signature = inspect.signature(module.forward).parameters
+        sanitized_kwargs = {}
+        for k, v in inputs_kwargs.items():
+            if k in module_signature:
+                sanitized_kwargs[k] = v
+        return sanitized_kwargs
\ No newline at end of file
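
Taken in isolation, the new helper boils down to a signature lookup. A minimal, self-contained sketch of the same idea follows; the standalone function name and toy layer are assumptions for illustration, not part of the patch.

import inspect
import torch.nn as nn

def sanitize_kwargs(inputs_kwargs, module):
    # Keep only the keys that appear as named parameters of module.forward
    allowed = inspect.signature(module.forward).parameters
    return {k: v for k, v in inputs_kwargs.items() if k in allowed}

class Layer(nn.Module):
    def forward(self, hidden_states, attention_mask=None, use_cache=False):
        return hidden_states

filtered = sanitize_kwargs(
    {"attention_mask": None, "use_cache": False, "cache_position": None},
    Layer(),
)
print(filtered)  # {'attention_mask': None, 'use_cache': False} -- 'cache_position' is dropped

One consequence of this approach is that keys a module would only accept through a catch-all **kwargs are dropped as well, since they do not appear as named parameters in the signature.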