Commit 6f30f051 authored by Casper Hansen's avatar Casper Hansen
Browse files

Add OPT support

parent 471f811b
from .mpt import MptAWQForCausalLM
from .llama import LlamaAWQForCausalLM
from .opt import OptAWQForCausalLM
\ No newline at end of file
......@@ -4,7 +4,8 @@ from awq.models.base import BaseAWQForCausalLM
AWQ_CAUSAL_LM_MODEL_MAP = {
"mpt": MptAWQForCausalLM,
'llama': LlamaAWQForCausalLM
"llama": LlamaAWQForCausalLM,
"opt": OptAWQForCausalLM
}
def check_and_get_model_type(model_dir, trust_remote_code=True):
......
from .base import BaseAWQForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
class OptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "OPTDecoderLayer"
@staticmethod
def get_model_layers(model: OPTForCausalLM):
return model.model.decoder.layers
@staticmethod
def get_act_for_scaling(module: OPTDecoderLayer):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: OPTForCausalLM, device: str):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
@staticmethod
def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.self_attn_layer_norm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
# linear 1
layers.append(dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
# linear 2
layers.append(dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
return layers
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment