from .base import BaseAWQForCausalLM


class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"

    @staticmethod
    def get_model_layers(model):
        # The stack of transformer blocks to quantize, in order.
        return model.transformer.blocks

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input: scale Wqkv against the preceding LayerNorm
        layers.append(dict(
            prev_op=module.norm_1,
            layers=[module.attn.Wqkv],
            inp=input_feat['attn.Wqkv'],
            module2inspect=module.attn,
            kwargs=module_kwargs,
        ))

        # attention output: scale out_proj against the fused QKV projection
        layers.append(dict(
            prev_op=module.attn.Wqkv,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj'],
        ))

        # linear 1: scale the FFN up-projection against the second LayerNorm
        layers.append(dict(
            prev_op=module.norm_2,
            layers=[module.ffn.up_proj],
            inp=input_feat['ffn.up_proj'],
            module2inspect=module.ffn,
        ))

        # linear 2: scale the FFN down-projection against the activation
        layers.append(dict(
            prev_op=module.ffn.act,
            layers=[module.ffn.down_proj],
            inp=input_feat['ffn.down_proj'],
        ))

        return layers

    @staticmethod
    def get_act_for_scaling(module):
        # The FFN activation is scalable; its scale vector has one entry
        # per up_proj output feature.
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )

    @staticmethod
    def move_embed(model, device):
        # Move the token embedding and embedding dropout to the target device
        # so calibration inputs can be embedded there.
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
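

# Usage sketch: how this class is typically exercised through AutoAWQ's
# top-level API. A minimal example, assuming the standard AutoAWQForCausalLM
# entry point dispatches to MptAWQForCausalLM for MPT checkpoints; the model
# path, output directory, and quant_config values are illustrative.
#
#   from awq import AutoAWQForCausalLM
#   from transformers import AutoTokenizer
#
#   model_path = "mosaicml/mpt-7b"  # any MPT checkpoint
#   quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
#
#   model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
#   tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
#
#   # quantize() walks get_model_layers(), collects per-block input features,
#   # then applies the scaling groups returned by get_layers_for_scaling().
#   model.quantize(tokenizer, quant_config=quant_config)
#   model.save_quantized("mpt-7b-awq")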