# mpt.py — auto-gptq quantization definition for MPT models.
from auto_gptq.modeling import BaseGPTQForCausalLM


class MPTGPTQForCausalLM(BaseGPTQForCausalLM):
    """GPTQ quantization wrapper for MPT models.

    Purely declarative: tells the auto-gptq framework where the repeated
    transformer blocks live in the model tree and which sub-modules inside
    each block are quantized.
    """

    # Class name of one repeated transformer block, and the attribute path
    # of the container holding all such blocks.
    layer_type = "MPTBlock"
    layers_block_name = "transformer.blocks"

    # Modules that sit outside the repeated blocks (token embedding and
    # final norm) — presumably excluded from per-block quantization;
    # confirm against BaseGPTQForCausalLM's handling.
    outside_layer_modules = ["transformer.wte", "transformer.norm_f"]

    # Sub-modules quantized within each block, grouped in processing order:
    # fused QKV projection, attention output projection, then the two
    # feed-forward projections.
    inside_layer_modules = [
        ["attn.Wqkv"],
        ["attn.out_proj"],
        ["ffn.up_proj"],
        ["ffn.down_proj"],
    ]


# Explicit public API of this module.
__all__ = ["MPTGPTQForCausalLM"]