Commit 6b1c96c7 authored by EC2 Default User

fixed catcher input name

parent fac1af55
@@ -113,8 +113,8 @@ class BaseAWQForCausalLM(nn.Module):
                 super().__init__()
                 self.module = module
-            def forward(self, inp, **kwargs):
-                inps.append(inp)
+            def forward(self, hidden_states, **kwargs):
+                inps.append(hidden_states)
                 layer_kwargs.update(kwargs)
                 raise ValueError  # early exit to break later inference
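Context for the change above: `Catcher` wraps the first decoder layer so that a single calibration forward pass records the layer's input tensor and keyword arguments, then aborts via the `ValueError`. The rename to `hidden_states` presumably matters because recent `transformers` decoder layers are called with `hidden_states=` as a keyword argument, which a parameter named `inp` would not receive. A minimal, self-contained sketch of the pattern (the `model`, `layers`, and `samples` names are illustrative, not taken from this commit):

```python
import torch.nn as nn

inps, layer_kwargs = [], {}


class Catcher(nn.Module):
    """Wraps the first decoder layer to capture its calibration inputs."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, hidden_states, **kwargs):
        # Using the same parameter name the model passes as a keyword
        # (hidden_states=...) is what the rename above fixes.
        inps.append(hidden_states)
        layer_kwargs.update(kwargs)
        raise ValueError  # early exit: only the captured inputs are needed


# Hypothetical usage around one calibration batch (names are illustrative):
# layers[0] = Catcher(layers[0])
# try:
#     model(samples)          # aborts inside the first layer
# except ValueError:
#     pass
# layers[0] = layers[0].module  # restore the original layer
```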
@@ -5,7 +5,7 @@ import torch.nn as nn
 from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
 from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
+from transformers.activations import NewGELUActivation
 from .qmodule import ScaledActivation
 from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
@@ -79,7 +79,7 @@ def scale_fc_fc(fc1, fc2, scales):
 @torch.no_grad()
 def scale_gelu_fc(gelu, fc, scales):
-    assert isinstance(gelu, nn.GELU) or isinstance(gelu, BloomGelu)
+    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
     assert isinstance(fc, nn.Linear)
     fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
@@ -195,7 +195,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             scale_fc_fc(prev_op, layers[0], scales)
         elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
             scale_ln_fcs(prev_op, layers, scales)
-        elif isinstance(prev_op, nn.GELU) or isinstance(prev_op, BloomGelu):
+        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
             new_module = ScaledActivation(prev_op, scales)
             set_op_by_name(module, prev_op_name, new_module)
             scale_gelu_fc(prev_op, layers[0], scales)
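For reference, the two hunks above rely on `ScaledActivation` (imported from `.qmodule`) dividing the activation output by per-channel scales, while `scale_gelu_fc` multiplies the same scales into the following linear layer's input channels, so the MLP output stays numerically unchanged. A rough sketch of that equivalence, assuming a `ScaledActivation` along these lines (the actual implementation in the repository may differ):

```python
import torch
import torch.nn as nn
from transformers.activations import NewGELUActivation


class ScaledActivation(nn.Module):
    """Sketch of ScaledActivation: divide the activation output by
    per-channel scales (assumed, not copied from qmodule.py)."""

    def __init__(self, module, scales):
        super().__init__()
        self.module = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.module(x) / self.scales.view(1, 1, -1).to(x.device)


# Toy equivalence check: dividing the activation output by `scales` and
# multiplying the next linear layer's input channels by `scales` (what
# scale_gelu_fc does) leaves the activation -> fc result unchanged.
torch.manual_seed(0)
act, fc = NewGELUActivation(), nn.Linear(8, 4, bias=False)
scales = torch.rand(8) + 0.5
x = torch.randn(2, 3, 8)

ref = fc(act(x))
with torch.no_grad():
    fc.weight.mul_(scales.view(1, -1))  # same op as scale_gelu_fc
out = fc(ScaledActivation(act, scales)(x))

assert torch.allclose(ref, out, atol=1e-5)
```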