Unverified Commit 0e77dbc1 authored by Casper's avatar Casper Committed by GitHub
Browse files

Fix MPT (#206)

parent 87350fef
...@@ -32,6 +32,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM): ...@@ -32,6 +32,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs): def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
layers = [] layers = []
if module_kwargs.get("output_attentions") is not None:
module_kwargs.pop("output_attentions")
# attention input # attention input
layers.append(dict( layers.append(dict(
prev_op=module.norm_1, prev_op=module.norm_1,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment