Unverified commit cc75d0e8 authored by Junyang Lin, committed by GitHub

Add qwen2 (#321)

parent 34085edc
@@ -13,3 +13,4 @@ from .qwen import QwenAWQForCausalLM
 from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
+from .qwen2 import Qwen2AWQForCausalLM
@@ -21,6 +21,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "qwen": QwenAWQForCausalLM,
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
+    "qwen2": Qwen2AWQForCausalLM
 }
...
@@ -58,6 +58,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "qwen": "AutoModelForCausalLM",
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
+    "qwen2": "AutoModelForCausalLM",
 }
...
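With `qwen2` registered in both maps, quantization follows the usual AutoAWQ flow. A minimal sketch of how the new entry is exercised; the checkpoint name, quant config values, and output path are placeholders, not part of this commit:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "Qwen/Qwen2-beta-7B"   # placeholder checkpoint name
quant_path = "qwen2-awq-w4-g128"    # placeholder output directory

# config.model_type == "qwen2" now resolves to Qwen2AWQForCausalLM
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}
model.quantize(tokenizer, quant_config=quant_config)

model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
```

The new Qwen2 model definition added by this commit follows.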
import tqdm
from typing import List, Tuple

from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer as OldQwen2DecoderLayer,
    Qwen2ForCausalLM as OldQwen2ForCausalLM,
)


class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Qwen2DecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldQwen2ForCausalLM):
        fuser = Qwen2Fuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldQwen2ForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldQwen2DecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldQwen2ForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
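
    # Each dict returned by `get_layers_for_scaling` describes one AWQ scaling
    # group: `prev_op` is the operation whose output feeds the listed linear
    # `layers`, `inp` holds the calibration inputs captured for those layers,
    # and `module2inspect` is the submodule re-run when searching for scales.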
    @staticmethod
    def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class Qwen2Fuser:
    def __init__(self, model: OldQwen2ForCausalLM):
        self.model = model

        # Collect (name, module) pairs for every Qwen2 decoder layer by class name.
        self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'Qwen2DecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldQwen2DecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # Fuse the separate q/k/v projections into a single linear layer.
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        # Swap the HF Qwen2 decoder stack for the fused Llama-like implementation.
        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
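
At inference time, `from_quantized(..., fuse_layers=True)` is what triggers the `fuse_layers` hook above. A short sketch, reusing the placeholder path from the quantization example earlier:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "qwen2-awq-w4-g128"  # placeholder path from the sketch above

# fuse_layers=True calls Qwen2AWQForCausalLM.fuse_layers, which runs
# Qwen2Fuser.fuse_transformer and replaces each Qwen2DecoderLayer with
# a fused LlamaLikeBlock.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)
```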