Unverified commit e9f62694 authored by Isotr0py, committed by GitHub

Add StableLM support (#410)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent 33dfb048
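
Reviewer note: with this change, quantizing a StableLM checkpoint should follow the usual AutoAWQ flow. A minimal sketch, where the model id, output path, and quant settings are illustrative and not part of this commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "stabilityai/stablelm-2-zephyr-1_6b"   # example StableLM checkpoint
quant_path = "stablelm-2-zephyr-1_6b-awq"           # example output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and tokenizer, run AWQ quantization, and save the result.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
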
@@ -15,4 +15,5 @@ from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
@@ -24,6 +24,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
 }
@@ -68,6 +68,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
 }
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM
from transformers.models.stablelm.modeling_stablelm import (
StableLmDecoderLayer as OldStableLmDecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class StableLmAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "StableLmDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStableLmForCausalLM):
        fuser = StableLmFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldStableLmForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldStableLmForCausalLM):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldStableLmForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldStableLmDecoderLayer, input_feat, module_kwargs
    ):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class StableLmFuser:
    def __init__(self, model: OldStableLmForCausalLM):
        self.model = model

        self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldStableLmDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = module.input_layernorm
            norm_2 = module.post_attention_layernorm
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    partial_rotary_factor=self.model.config.partial_rotary_factor,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
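
StableLmFuser above is exercised when a quantized model is loaded with layer fusion enabled. A minimal sketch, assuming the standard AutoAWQ loading API; the path is the illustrative one from the sketch near the top of this commit:

from awq import AutoAWQForCausalLM

# fuse_layers=True invokes StableLmAWQForCausalLM.fuse_layers(), which replaces
# each StableLmDecoderLayer with a fused LlamaLikeBlock built by StableLmFuser.
model = AutoAWQForCausalLM.from_quantized("stablelm-2-zephyr-1_6b-awq", fuse_layers=True)
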
@@ -29,9 +29,7 @@ class RoPE(nn.Module):
         super(RoPE, self).__init__()
         self.freqs_cis = nn.Parameter(
-            self.precompute_freqs_cis(
-                head_dim, max_seq_len * 2, rope_theta
-            ).to(device),
+            self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
             requires_grad=False,
         )
@@ -118,6 +116,7 @@ class QuantAttentionFused(nn.Module):
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         head_dim=None,
         **kwargs
     ):
@@ -167,8 +166,9 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
-            self.rotary_dim = self.head_dim
+            self.partial_rotary_factor = partial_rotary_factor
+            self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+            self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
             self.is_neox = True
 
     def forward(
@@ -209,12 +209,26 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)
 
-        if seqlen > 1 or not FT_INSTALLED:
+        if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
             if not self.use_alibi:
-                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
+                # Partial rotary embedding
+                if self.partial_rotary_factor < 1:
+                    xq_rot, xq_pass = (
+                        xq[..., : self.rotary_dim],
+                        xq[..., self.rotary_dim :],
+                    )
+                    xk_rot, xk_pass = (
+                        xk[..., : self.rotary_dim],
+                        xk[..., self.rotary_dim :],
+                    )
+                    xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
+                    xq = torch.cat((xq_rot, xq_pass), dim=-1)
+                    xk = torch.cat((xk_rot, xk_pass), dim=-1)
+                else:
+                    xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
             values_store = xv.transpose(2, 1)
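The partial-rotary branch above rotates only the first rotary_dim channels of each head and passes the rest through unchanged, matching how StableLM applies RoPE to a fraction of the head dimension. A standalone sketch of the split; the shapes and config values are illustrative:

import torch

head_dim = 64
partial_rotary_factor = 0.25                         # example StableLM-style config value
rotary_dim = int(head_dim * partial_rotary_factor)   # only 16 of 64 channels get rotated

xq = torch.randn(1, 8, 32, head_dim)                 # (batch, seqlen, n_heads, head_dim)
xq_rot, xq_pass = xq[..., :rotary_dim], xq[..., rotary_dim:]
# ... RoPE would be applied to xq_rot only ...
xq = torch.cat((xq_rot, xq_pass), dim=-1)            # recombined; the overall shape is unchanged
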
@@ -79,6 +79,7 @@ class LlamaLikeBlock(nn.Module):
         dev,
         max_seq_len,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         use_alibi=False,
         head_dim=None,
     ):
@@ -103,6 +104,7 @@ class LlamaLikeBlock(nn.Module):
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
             head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)