Commit e2858cef authored by Casper Hansen

Refactor model class. Add LLaMa support (LLaMa-2, Vicuna, etc).

parent d430694b
awq/models/__init__.py
-from .mpt import MptAWQForCausalLM
\ No newline at end of file
+from .mpt import MptAWQForCausalLM
+from .llama import LlamaAWQForCausalLM
\ No newline at end of file
awq/models/auto.py
 from transformers import AutoConfig
-from awq.models import MptAWQForCausalLM
+from awq.models import *
 from awq.models.base import BaseAWQForCausalLM

 AWQ_CAUSAL_LM_MODEL_MAP = {
     "mpt": MptAWQForCausalLM,
+    "llama": LlamaAWQForCausalLM
 }

 def check_and_get_model_type(model_dir, trust_remote_code=True):
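The body of check_and_get_model_type is collapsed in this view. A minimal sketch of the dispatch it presumably implements, assuming it reads model_type from the Hugging Face config and validates it against AWQ_CAUSAL_LM_MODEL_MAP (the checkpoint path below is only an example):

def check_and_get_model_type(model_dir, trust_remote_code=True):
    # Read the checkpoint's config to find its architecture family ("mpt", "llama", ...).
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{config.model_type} isn't supported yet.")
    return config.model_type

# Example: resolve the AWQ wrapper class for a LLaMa-style checkpoint.
model_type = check_and_get_model_type("meta-llama/Llama-2-7b-hf")
awq_model_class = AWQ_CAUSAL_LM_MODEL_MAP[model_type]   # -> LlamaAWQForCausalLM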
awq/models/base.py
@@ -270,15 +270,15 @@ class BaseAWQForCausalLM:
     @staticmethod
     def _scale_activations(self, layer):
-        act_function = self.get_act_from_layer(layer)
+        scale_dict = self.get_act_for_scaling(layer)

-        if act_function is not None and not isinstance(act_function, ScaledActivation):
-            param = next(layer.parameters())
+        if scale_dict['is_scalable']:
+            if not isinstance(scale_dict['scale_layer'], ScaledActivation):
+                param = next(layer.parameters())

-            # get activation scale
-            scale_dict = self.get_act_for_scaling(layer)
-            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
+                # get activation scale
+                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

-            # scale activation
-            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
-            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
+                # scale activation
+                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
+                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
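ScaledActivation and set_op_by_name are referenced here but defined elsewhere in the repo. For context, a minimal sketch of what such a wrapper typically looks like, modelled on the original llm-awq implementation; treat the exact signature as an assumption rather than this commit's code:

import torch.nn as nn

class ScaledActivation(nn.Module):
    """Wrap an activation and divide its output by per-channel scales,
    undoing the scales that AWQ folded into the preceding linear layer."""
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)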
awq/models/llama.py (new file)
from .base import BaseAWQForCausalLM
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: LlamaDecoderLayer):
        # LLaMa's gated SiLU MLP has no standalone activation to scale.
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: LlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: LlamaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input: q/k/v are scaled together against the input LayerNorm
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out: only scalable when o_proj's shape matches v_proj's
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # fc1: gate_proj and up_proj share the post-attention LayerNorm
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # fc2: down_proj is scaled against up_proj
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers
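For context, each dict returned by get_layers_for_scaling describes one prev_op -> linear-layers group that the AWQ search scales jointly. A rough sketch of how a caller might walk these groups; compute_scales and apply_scale are hypothetical stand-ins for the actual search and scaling code, not part of this commit:

def scale_llama_layer(module, input_feat, module_kwargs, compute_scales, apply_scale):
    # Illustration only: consume the group dicts produced above.
    for group in LlamaAWQForCausalLM.get_layers_for_scaling(module, input_feat, module_kwargs):
        scales = compute_scales(
            linears=group['layers'],             # layers whose input channels get scaled
            calib_input=group['inp'],            # cached activations feeding those layers
            parent=group.get('module2inspect'),  # larger module re-run to measure error
            kwargs=group.get('kwargs', {}),
        )
        # Fold 1/scales into prev_op and scales into each linear layer.
        apply_scale(group['prev_op'], group['layers'], scales)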
awq/models/mpt.py
@@ -7,6 +7,20 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     def get_model_layers(model):
         return model.transformer.blocks

+    @staticmethod
+    def get_act_for_scaling(module):
+        return dict(
+            is_scalable=True,
+            scale_name="ffn.act",
+            scale_layer=module.ffn.act,
+            scale_shape=module.ffn.up_proj.out_features
+        )
+
+    @staticmethod
+    def move_embed(model, device):
+        model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+
     @staticmethod
     def get_layers_for_scaling(module, input_feat, module_kwargs):
         layers = []
@@ -42,21 +56,4 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['ffn.down_proj']
         ))

-        return layers
-
-    @staticmethod
-    def get_act_from_layer(layer):
-        return layer.ffn.act
-
-    @staticmethod
-    def get_act_for_scaling(module):
-        return dict(
-            scale_name="ffn.act",
-            scale_layer=module.ffn.act,
-            scale_shape=module.ffn.up_proj.out_features
-        )
-
-    @staticmethod
-    def move_embed(model, device):
-        model.transformer.wte = model.transformer.wte.to(device)
-        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
+        return layers
\ No newline at end of file
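The net effect of the refactor is that each model class now answers the activation-scaling question itself: MPT exposes its FFN activation (scale_name="ffn.act", is_scalable=True), while LLaMa reports is_scalable=False, so _scale_activations becomes a no-op for it. A small illustration of the two contracts; the block objects below are stand-ins for real decoder layers:

# Illustration only: how the base class's _scale_activations sees each model.
mpt_block = mpt_model.transformer.blocks[0]    # stand-in MPT block
llama_block = llama_model.model.layers[0]      # stand-in LLaMa decoder layer

assert MptAWQForCausalLM.get_act_for_scaling(mpt_block)['is_scalable'] is True
assert LlamaAWQForCausalLM.get_act_for_scaling(llama_block)['is_scalable'] is False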