Unverified Commit 63d2aaec authored by Casper's avatar Casper Committed by GitHub
Browse files

Robust quantization for Catcher (#209)

parent e440c7ac
...@@ -319,8 +319,16 @@ class AwqQuantizer: ...@@ -319,8 +319,16 @@ class AwqQuantizer:
super().__init__() super().__init__()
self.module = module self.module = module
def forward(self, hijacked_inputs, **kwargs): def forward(self, *args, **kwargs):
inps.append(hijacked_inputs) # assume first input to forward is hidden states
if len(args) > 0:
hidden_states = args[0]
del args
else:
first_key = list(kwargs.keys())[0]
hidden_states = kwargs.pop(first_key)
inps.append(hidden_states)
layer_kwargs.update(kwargs) layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference raise ValueError # early exit to break later inference
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment