Unverified Commit 33dfb048 authored by 少年, committed by GitHub

add starcoder2 support (#406)


Co-authored-by: charrli <charrli@tencent.com>
parent eb85f67d
@@ -15,3 +15,4 @@ from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .starcoder2 import Starcoder2AWQForCausalLM
\ No newline at end of file
@@ -24,6 +24,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "starcoder2": Starcoder2AWQForCausalLM,
 }
@@ -68,6 +68,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "starcoder2": "AutoModelForCausalLM",
 }
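With the three registration points above in place (the model import, AWQ_CAUSAL_LM_MODEL_MAP, and TRANSFORMERS_AUTO_MAPPING_DICT), a StarCoder2 checkpoint flows through the standard AutoAWQ entry points. A minimal quantization sketch follows; the checkpoint id and output path are placeholders, and the quant_config values are AutoAWQ's commonly used defaults rather than anything this commit mandates:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "bigcode/starcoder2-7b"  # placeholder checkpoint id
quant_path = "starcoder2-7b-awq"      # placeholder output directory

# Load the FP16 model; model_type "starcoder2" in config.json now resolves
# through the maps added above.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# 4-bit AWQ with the usual default settings.
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
model.quantize(tokenizer, quant_config=quant_config)

model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)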
import tqdm
from typing import List, Tuple

from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.starcoder2.modeling_starcoder2 import (
    Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
    Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
)


class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Starcoder2DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStarcoder2ForCausalLM):
        fuser = Starcoder2Fuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldStarcoder2ForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
        # The activation between mlp.c_fc and mlp.c_proj is scalable, so AWQ
        # can also search a per-channel scale over its c_fc-sized output.
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
            scale_layer=module.mlp.act,
            scale_shape=module.mlp.c_fc.out_features,
        )

    @staticmethod
    def move_embed(model: OldStarcoder2ForCausalLM, device):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    @staticmethod
    def get_layers_for_scaling(
        module: OldStarcoder2DecoderLayer, input_feat, module_kwargs
    ):
        layers = []

        # attention input: scale q/k/v jointly against the preceding layernorm
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out: only possible when v_proj and o_proj have matching
        # shapes, i.e. no grouped-query head-count mismatch to fold scales through
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.c_fc],
                inp=input_feat["mlp.c_fc"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.act,
                layers=[module.mlp.c_proj],
                inp=input_feat["mlp.c_proj"],
            )
        )

        return layers
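
# How these groups are consumed (a rough sketch of AutoAWQ's behavior, not part
# of this file's contract): for each dict, the scale search feeds `inp` through
# `module2inspect` (or the single layer in `layers`), finds a per-channel scale
# that minimizes the weight-quantization error of `layers`, and folds the
# inverse scale into `prev_op` so the block's float output is unchanged.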


class Starcoder2Fuser:
    def __init__(self, model: OldStarcoder2ForCausalLM):
        self.model = model

        self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldStarcoder2DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # Concatenate the q/k/v weights into a single fused projection.
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            # StarCoder2 uses standard LayerNorm rather than RMSNorm, so the
            # original norm modules are kept; wrapping them in
            # FasterTransformerRMSNorm would silently drop the mean-centering
            # and bias terms.
            norm_1 = module.input_layernorm
            norm_2 = module.post_attention_layernorm
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file
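Once merged, fuse_layers runs automatically when a quantized checkpoint is loaded with fusing enabled. A minimal sketch of that path, assuming the quantized weights were saved as in the earlier example (the path is a placeholder):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "starcoder2-7b-awq"  # placeholder: output of the quantization step

# fuse_layers=True triggers Starcoder2AWQForCausalLM.fuse_layers(), which swaps
# each Starcoder2DecoderLayer for a fused LlamaLikeBlock before generation.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

tokens = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids.cuda()
out = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(out[0], skip_special_tokens=True))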