Commit e2858cef authored by Casper Hansen

Refactor model class. Add LLaMa support (LLaMa-2, Vicuna, etc).

parent d430694b
awq/models/__init__.py
 from .mpt import MptAWQForCausalLM
+from .llama import LlamaAWQForCausalLM
awq/models/auto.py
 from transformers import AutoConfig
-from awq.models import MptAWQForCausalLM
+from awq.models import *
 from awq.models.base import BaseAWQForCausalLM

 AWQ_CAUSAL_LM_MODEL_MAP = {
     "mpt": MptAWQForCausalLM,
+    'llama': LlamaAWQForCausalLM
 }

 def check_and_get_model_type(model_dir, trust_remote_code=True):
 ...
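The body of check_and_get_model_type is truncated in this hunk. As a rough sketch (an assumption, not part of the diff), it presumably reads model_type from the Hugging Face config and keys it into AWQ_CAUSAL_LM_MODEL_MAP:

# Sketch only: hypothetical reconstruction of the truncated helper,
# assuming the registry is keyed off AutoConfig.model_type.
from transformers import AutoConfig

def check_and_get_model_type(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

# With this commit, a LLaMa-family checkpoint resolves to the new class:
# model_type = check_and_get_model_type("meta-llama/Llama-2-7b-hf")   # -> "llama"
# model_class = AWQ_CAUSAL_LM_MODEL_MAP[model_type]                   # -> LlamaAWQForCausalLM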
awq/models/base.py
@@ -270,15 +270,15 @@ class BaseAWQForCausalLM:
     @staticmethod
     def _scale_activations(self, layer):
-        act_function = self.get_act_from_layer(layer)
-
-        if act_function is not None and not isinstance(act_function, ScaledActivation):
-            param = next(layer.parameters())
-
-            # get activation scale
-            scale_dict = self.get_act_for_scaling(layer)
-            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
-
-            # scale activation
-            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
+        scale_dict = self.get_act_for_scaling(layer)
+
+        if scale_dict['is_scalable']:
+            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
+                param = next(layer.parameters())
+
+                # get activation scale
+                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
+
+                # scale activation
+                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
+                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
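ScaledActivation itself is not shown in this commit. A minimal sketch of what such a wrapper typically looks like, assuming it follows the upstream llm-awq implementation (divide the activation output by learned per-channel scales):

import torch
import torch.nn as nn

class ScaledActivation(nn.Module):
    """Wraps an activation module and divides its output by per-channel scales."""
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)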
awq/models/llama.py (new file)
from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM


class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # fc1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # fc2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers
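The v_proj/o_proj shape check skips the attention-output scaling pair when the two projections differ in shape, which appears to be the grouped-query-attention case discussed in the linked llm-awq pull request. Each dict returned by get_layers_for_scaling describes one scaling group; a hypothetical illustration (not part of this commit; decoder_layer, input_feat, and apply_scale are placeholder names) of how a quantizer might walk those groups:

for group in LlamaAWQForCausalLM.get_layers_for_scaling(decoder_layer, input_feat, {}):
    prev_op = group['prev_op']             # op producing the shared input (e.g. input_layernorm)
    linears = group['layers']              # linear layers that consume that input
    calib_inp = group['inp']               # cached calibration activations for the scale search
    parent = group.get('module2inspect')   # larger module to re-run when searching scales, if any
    # apply_scale(prev_op, linears, calib_inp, parent)  # fold AWQ scales into the weights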
awq/models/mpt.py
@@ -7,6 +7,20 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     def get_model_layers(model):
         return model.transformer.blocks
 
+    @staticmethod
+    def get_act_for_scaling(module):
+        return dict(
+            is_scalable=True,
+            scale_name="ffn.act",
+            scale_layer=module.ffn.act,
+            scale_shape=module.ffn.up_proj.out_features
+        )
+
+    @staticmethod
+    def move_embed(model, device):
+        model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+
     @staticmethod
     def get_layers_for_scaling(module, input_feat, module_kwargs):
         layers = []
@@ -42,21 +56,4 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['ffn.down_proj']
         ))
 
         return layers
-
-    @staticmethod
-    def get_act_from_layer(layer):
-        return layer.ffn.act
-
-    @staticmethod
-    def get_act_for_scaling(module):
-        return dict(
-            scale_name="ffn.act",
-            scale_layer=module.ffn.act,
-            scale_shape=module.ffn.up_proj.out_features
-        )
-
-    @staticmethod
-    def move_embed(model, device):
-        model.transformer.wte = model.transformer.wte.to(device)
-        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
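Together with the base-class change above, the new is_scalable flag replaces the old get_act_from_layer probe. An illustration only (mpt_block and llama_layer are placeholder module instances):

mpt_info = MptAWQForCausalLM.get_act_for_scaling(mpt_block)
# -> {'is_scalable': True, 'scale_name': 'ffn.act', ...}: _scale_activations wraps
#    mpt_block.ffn.act in a ScaledActivation sized to ffn.up_proj.out_features.

llama_info = LlamaAWQForCausalLM.get_act_for_scaling(llama_layer)
# -> {'is_scalable': False}: _scale_activations returns without modifying the layer.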