Unverified commit cef9f113 authored by Aoyu, committed by GitHub

Add Baichuan2 Support (#247)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent 9e8e28b2
@@ -10,5 +10,6 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
 from .aquila import AquilaAWQForCausalLM
 from .yi import YiAWQForCausalLM
 from .qwen import QwenAWQForCausalLM
+from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
@@ -19,6 +19,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "aquila": AquilaAWQForCausalLM,
     "Yi": YiAWQForCausalLM,
     "qwen": QwenAWQForCausalLM,
+    "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
 }
...
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm


class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BaichuanLayer"
    max_new_tokens_key = "model_max_length"

    @staticmethod
    def fuse_layers(model):
        fuser = BaichuanFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.W_pack],
            inp=input_feat['self_attn.W_pack'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # # attention out
        # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
        #     layers.append(dict(
        #         prev_op=module.self_attn.v_proj,
        #         layers=[module.self_attn.o_proj],
        #         inp=input_feat['self_attn.o_proj'],
        #     ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        layers.append(dict(
            prev_op=module.self_attn.W_pack,
            layers=[module.self_attn.o_proj],
            inp=input_feat['self_attn.o_proj'],
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class BaichuanFuser:
    def __init__(self, model):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # qkv = fuse_qkv(
            #     module,
            #     module.self_attn.q_proj,
            #     module.self_attn.k_proj,
            #     module.self_attn.v_proj
            # )
            # Baichuan already packs Q, K and V into a single linear (W_pack),
            # so no separate QKV fusion step is needed here.
            qkv = module.self_attn.W_pack
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens,
                use_alibi=True
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
@@ -55,6 +55,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "aquila": "AutoModelForCausalLM",
     "Yi": "AutoModelForCausalLM",
     "qwen": "AutoModelForCausalLM",
+    "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
 }
@@ -90,6 +91,7 @@ class BaseAWQForCausalLM(nn.Module):
             self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()
+
         self.is_quantized = True

     @staticmethod
...
@@ -43,7 +43,7 @@ class LlamaLikeBlock(nn.Module):
     """
     def __init__(
         self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
+        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta, use_alibi=False
     ):
         super().__init__()
         self.n_heads = n_heads
@@ -52,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
+            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
@@ -185,4 +185,4 @@ class FalconDecoderLayer(nn.Module):
         out = h_attn + h_mlp

-        return out, None, past_key_value
\ No newline at end of file
+        return out, None, past_key_value
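
The new use_alibi flag is forwarded to the fused attention because Baichuan2-13B uses ALiBi position biases rather than rotary embeddings. As a rough illustration of what that flag selects, here is a simplified sketch of an ALiBi bias matrix; this is not the library's fused kernel, just the idea behind it.

import torch

def alibi_bias(n_heads: int, seq_len: int) -> torch.Tensor:
    # One slope per head: 2^(-8*i/n_heads) for i = 1..n_heads (assumes n_heads is a power of two).
    slopes = torch.tensor([2 ** (-8 * (i + 1) / n_heads) for i in range(n_heads)])
    pos = torch.arange(seq_len)
    # Non-positive distance from each query position back to each earlier key position.
    distance = (pos[None, :] - pos[:, None]).clamp(max=0)
    # [n_heads, seq_len, seq_len] additive bias applied to the attention scores.
    return slopes[:, None, None] * distance[None, :, :].float()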
@@ -3,7 +3,7 @@ import logging
 from typing import List, Union
 from datasets import load_dataset

-def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
                       tokenizer=None, n_samples=512, block_size=512,
                       split="train", text_column="text"):
     if isinstance(data, str):
@@ -15,18 +15,30 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
         dataset = dataset.shuffle(seed=42)
     elif isinstance(data, list):
-        dataset = [{text_column: text} for text in data]
+        if isinstance(data[0], str):
+            dataset = [{text_column: text} for text in data]
+        elif isinstance(data[0][0], int):
+            dataset = data
+        else:
+            raise NotImplementedError(
+                "Either pass a string to a huggingface dataset or a list"
+                "that is preprocessed with one sample of text per element"
+                " or a list of list of int for tokenized words.")
     else:
         raise NotImplementedError(
             "Either pass a string to a huggingface dataset or a list"
-            "that is preprocessed with one sample of text per element.")
+            "that is preprocessed with one sample of text per element"
+            " or a list of list of int for tokenized words.")

     samples = []
     n_run = 0
     for data in dataset:
-        line = data[text_column]
-        line = line.strip()
-        line_encoded = tokenizer.encode(line)
+        if isinstance(data, list):
+            line_encoded = data
+        else:
+            line = data[text_column]
+            line = line.strip()
+            line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
             continue
         sample = torch.tensor([line_encoded])
...
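
With this change, calibration samples can be passed already tokenized (a list of lists of token ids) instead of raw strings. A minimal sketch of how that might be used through the quantize entry point, assuming model, tokenizer and quant_config from the earlier example and with placeholder sentences:

# Pre-tokenized calibration data: each element is one sample's token ids.
calib_data = [
    tokenizer.encode("Baichuan2 is a series of open large language models."),
    tokenizer.encode("AWQ is an activation-aware weight quantization method."),
]
model.quantize(tokenizer, quant_config=quant_config, calib_data=calib_data)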
@@ -156,6 +156,7 @@ def main(args):
         {"context": 512, "n_generate": 512},
         {"context": 1024, "n_generate": 1024},
         {"context": 2048, "n_generate": 2048},
+        {"context": 4096, "n_generate": 4096},
     ]

     if args.generator == "torch":
...
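
The added 4096-token setting extends the context/generation sweep in the benchmark example. A typical invocation might look like the line below; the script path and --model_path flag are assumptions, only --generator is visible in this diff.

python examples/benchmark.py --model_path <path-to-quantized-baichuan2-awq> --generator torch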