Commit fcd9637c authored by gaoqiong

Merge branch 'v0.2.5_develop' into 'main'

v0.2.5

See merge request dcutoolkit/deeplearing/autoawq!2
parents 7724cca1 427f5481
import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.gemma.modeling_gemma import (
GemmaDecoderLayer as OldGemmaDecoderLayer,
GemmaForCausalLM as OldGemmaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class GemmaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GemmaDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldGemmaForCausalLM):
fuser = GemmaFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldGemmaForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldGemmaDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldGemmaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class GemmaFuser:
def __init__(self, model: OldGemmaForCausalLM):
self.model = model
self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldGemmaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
with torch.no_grad():
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
module.input_layernorm.weight += 1
module.post_attention_layernorm.weight += 1
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.eps
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.eps,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
head_dim=self.model.config.head_dim,
)
)
with torch.no_grad():
# Normalize Gemma's embedding layer
self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
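# Illustrative only (not part of the upstream diff): a tiny numeric check of the
# "(1 + weight)" fold performed in fuse_transformer() above. The assumption is that
# GemmaRMSNorm scales the normalized input by (1 + w) while the fused Llama-style
# RMSNorm scales by w alone, so shifting the stored weight by +1 once lets the
# weight-only norm reproduce Gemma's output exactly.
def _check_gemma_norm_folding(hidden: int = 8, eps: float = 1e-6):
    import torch

    def gemma_style(x, w):
        x_hat = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x_hat * (1.0 + w)  # Gemma multiplies by (1 + weight)

    def llama_style(x, w):
        x_hat = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
        return x_hat * w  # fused norm multiplies by the weight only

    x = torch.randn(2, hidden)
    w = torch.randn(hidden)
    # Folding +1 into the weight makes the two formulations identical.
    assert torch.allclose(gemma_style(x, w), llama_style(x, w + 1.0))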
from .base import BaseAWQForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import (
GPTBigCodeForCausalLM,
GPTBigCodeBlock as OldGptBigCodeBlock,
)
class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTBigCodeBlock"
max_seq_len_key = "n_positions"
@staticmethod
def get_model_layers(model: GPTBigCodeForCausalLM):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module: OldGptBigCodeBlock):
return dict(
is_scalable=True,
scale_name="mlp.act",
scale_layer=module.mlp.act,
scale_shape=module.mlp.c_fc.out_features,
)
@staticmethod
def move_embed(model: GPTBigCodeForCausalLM, device):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.wpe = model.transformer.wpe.to(device)
model.transformer.drop = model.transformer.drop.to(device)
@staticmethod
def get_layers_for_scaling(module: OldGptBigCodeBlock, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.ln_1,
layers=[module.attn.c_attn],
inp=input_feat["attn.c_attn"],
module2inspect=module.attn,
kwargs=module_kwargs,
)
)
# linear 1
layers.append(
dict(
prev_op=module.ln_2,
layers=[module.mlp.c_fc],
inp=input_feat["mlp.c_fc"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.act,
layers=[module.mlp.c_proj],
inp=input_feat["mlp.c_proj"],
)
)
return layers
from .base import BaseAWQForCausalLM
from transformers.models.gpt_neox.modeling_gpt_neox import (
GPTNeoXLayer,
GPTNeoXForCausalLM,
)
class GPTNeoXAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTNeoXDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def get_model_layers(model: GPTNeoXForCausalLM):
return model.gpt_neox.layers
@staticmethod
def get_act_for_scaling(module: GPTNeoXLayer):
return dict(
is_scalable=True,
scale_name="mlp.act",
scale_layer=module.mlp.act,
scale_shape=module.mlp.dense_h_to_4h.out_features,
)
@staticmethod
def move_embed(model: GPTNeoXForCausalLM, device: str):
model.gpt_neox.embed_in = model.gpt_neox.embed_in.to(device)
@staticmethod
def get_layers_for_scaling(module: GPTNeoXLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[module.attention.query_key_value],
inp=input_feat["attention.query_key_value"],
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
"""
layers.append(dict(
prev_op=module.attention.query_key_value,
layers=[module.attention.dense],
inp=input_feat['attention.dense'],
))
"""
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.dense_h_to_4h],
inp=input_feat["mlp.dense_h_to_4h"],
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.act,
layers=[module.mlp.dense_4h_to_h],
inp=input_feat["mlp.dense_4h_to_h"],
)
)
return layers
from .base import BaseAWQForCausalLM
from transformers.models.gptj.modeling_gptj import GPTJForCausalLM, GPTJBlock
class GPTJAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "GPTJBlock"
max_seq_len_key = "n_positions"
@staticmethod
def get_model_layers(model: GPTJForCausalLM):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module: GPTJBlock):
return dict(
is_scalable=True,
scale_name="mlp.act",
scale_layer=module.mlp.act,
scale_shape=module.mlp.fc_in.out_features,
)
@staticmethod
def move_embed(model: GPTJForCausalLM, device: str):
model.transformer.wte = model.transformer.wte.to(device)
@staticmethod
def get_layers_for_scaling(module: GPTJBlock, input_feat, module_kwargs):
layers = []
# attention input + linear 1
layers.append(
dict(
prev_op=module.ln_1,
layers=[
module.attn.q_proj,
module.attn.k_proj,
module.attn.v_proj,
module.mlp.fc_in,
],
inp=input_feat["attn.q_proj"],
module2inspect=module,
kwargs=module_kwargs,
)
)
# attention out
layers.append(
dict(
prev_op=module.attn.v_proj,
layers=[module.attn.out_proj],
inp=input_feat["attn.out_proj"],
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.act,
layers=[module.mlp.fc_out],
inp=input_feat["mlp.fc_out"],
)
)
return layers
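# Note (added for clarity, not part of the upstream diff): unlike the Llama-style
# models, GPT-J applies attention and the MLP in parallel to the same ln_1 output,
# which is why q/k/v and mlp.fc_in are grouped above under a single prev_op and the
# whole block is passed as module2inspect.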
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldLlamaForCausalLM):
fuser = LlamaFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldLlamaForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldLlamaForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class LlamaFuser:
def __init__(self, model: OldLlamaForCausalLM):
self.model = model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldLlamaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
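# Usage sketch (illustrative, not part of the upstream diff). The fuser is normally
# reached through the high-level loader rather than called directly; the loader call
# below reflects AutoAWQ's public API as understood for v0.2.x and should be treated
# as an assumption rather than something this diff defines:
#
#     from awq import AutoAWQForCausalLM
#     model = AutoAWQForCausalLM.from_quantized("path/to/llama-awq", fuse_layers=True)
#
# Fusing by hand on an already-quantized transformers LlamaForCausalLM is equivalent
# to what fuse_layers() above does:
def _fuse_quantized_llama(hf_model: OldLlamaForCausalLM) -> OldLlamaForCausalLM:
    # Replaces hf_model.model with a LlamaLikeModel built from fused blocks.
    LlamaAWQForCausalLM.fuse_layers(hf_model)
    return hf_model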
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import (
LlavaForConditionalGeneration as OldLlavaForConditionalGeneration,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlavaAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "LlamaDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldLlavaForConditionalGeneration):
fuser = LlavaFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldLlavaForConditionalGeneration):
return model.language_model.model.layers
@staticmethod
def get_act_for_scaling(module: OldLlamaDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldLlavaForConditionalGeneration, device: str):
model.language_model.model.embed_tokens = model.get_input_embeddings().to(
device
)
@staticmethod
def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class LlavaFuser:
def __init__(self, model: OldLlavaForConditionalGeneration):
self.model = model.language_model
self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "LlamaDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldLlamaDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OldMistralDecoderLayer,
MistralForCausalLM as OldMistralForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MistralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MistralDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldMistralForCausalLM):
fuser = MistralFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldMistralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldMistralDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldMistralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(
module: OldMistralDecoderLayer, input_feat, module_kwargs
):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class MistralFuser:
def __init__(self, model: OldMistralForCausalLM):
self.model = model
self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "MistralDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMistralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from awq.modules.fused.moe import FusedSparseMoeBlock
from awq.utils.fused_utils import fuse_qkv, fuse_linears
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM,
)
from awq.modules.linear import WQLinear_GEMM
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_seq_len_key = "max_position_embeddings"
modules_to_not_convert = ["gate"]
@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(
module: OldMixtralDecoderLayer, input_feat, module_kwargs
):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear in
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[
w
for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat["block_sparse_moe"],
module2inspect=module.block_sparse_moe,
)
)
# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(
dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
)
)
return layers
class MixtralFuser:
def __init__(self, model: OldMixtralForCausalLM):
self.model = model
self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
sparse_moe = module.block_sparse_moe
if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM):
fused_w1w3s = [
fuse_linears(
[
sparse_moe.experts[i].w1,
sparse_moe.experts[i].w3,
],
device,
)
for i in range(len(sparse_moe.experts))
]
stacked_w1w3s = fuse_linears(
fused_w1w3s, device, dim=0, operation=torch.stack
)
stacked_w2s = fuse_linears(
[expert.w2 for expert in sparse_moe.experts],
device,
dim=0,
operation=torch.stack,
)
sparse_moe = FusedSparseMoeBlock(
top_k=sparse_moe.top_k,
gate=sparse_moe.gate,
ws=stacked_w1w3s,
w2s=stacked_w2s,
)
blocks.append(
MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)
model_norm = FasterTransformerRMSNorm(
self.model.model.norm.weight,
self.model.model.norm.variance_epsilon,
)
self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
model_norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
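# Illustrative sketch (not part of the upstream diff) of the expert stacking done via
# fuse_linears(..., operation=torch.stack) above, shown on plain dense tensors. The
# real code operates on quantized WQLinear_GEMM modules; this only demonstrates the
# assumed layout: per-expert w1 and w3 are concatenated along the output dimension,
# then all experts are stacked into one (n_experts, ...) tensor so a single batched
# matmul can serve whichever experts the router selects.
def _stack_expert_weights(n_experts: int = 4, hidden: int = 16, inter: int = 32):
    import torch

    w1 = [torch.randn(inter, hidden) for _ in range(n_experts)]
    w3 = [torch.randn(inter, hidden) for _ in range(n_experts)]
    w2 = [torch.randn(hidden, inter) for _ in range(n_experts)]

    # Fuse w1/w3 per expert, then stack across experts.
    w13_stacked = torch.stack([torch.cat([a, b], dim=0) for a, b in zip(w1, w3)])
    w2_stacked = torch.stack(w2)

    assert w13_stacked.shape == (n_experts, 2 * inter, hidden)
    assert w2_stacked.shape == (n_experts, hidden, inter)
    return w13_stacked, w2_stacked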
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_seq_len_key = "max_seq_len"
@staticmethod
def fuse_layers(model: MptForCausalLM):
fuser = MptFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: MptForCausalLM):
return model.transformer.blocks
@staticmethod
def get_act_for_scaling(module: OldMptBlock):
return dict(
is_scalable=True,
scale_name="ffn.act",
scale_layer=module.ffn.act,
scale_shape=module.ffn.up_proj.out_features,
)
@staticmethod
def move_embed(model: MptForCausalLM, device: str):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
layers = []
if module_kwargs.get("output_attentions") is not None:
module_kwargs.pop("output_attentions")
# attention input
layers.append(
dict(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat["attn.Wqkv"],
module2inspect=module.attn,
kwargs=module_kwargs,
)
)
# attention output
layers.append(
dict(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat["attn.out_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat["ffn.up_proj"],
module2inspect=module.ffn,
)
)
# linear 2
layers.append(
dict(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat["ffn.down_proj"],
)
)
return layers
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel
class MptFuser:
def __init__(self, model: MptForCausalLM):
self.model = model
self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
(name, module)
for name, module in self.model.named_modules()
if "mptblock" in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMptBlock
for module in self.model.transformer.blocks:
blocks.append(
MPTBlock(
self.model.config.d_model,
self.model.config.n_heads,
module.attn.Wqkv,
module.attn.out_proj,
module.ffn,
module.norm_1,
module.norm_2,
next(iter(module.state_dict().values())).device,
self.model.config.max_seq_len,
)
)
self.model.transformer = MPTModel(
self.model.config.vocab_size,
blocks,
self.model.transformer.wte,
self.model.transformer.norm_f,
)
setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
from .base import BaseAWQForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
class OptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "OPTDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def get_model_layers(model: OPTForCausalLM):
return model.model.decoder.layers
@staticmethod
def get_act_for_scaling(module: OPTDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OPTForCausalLM, device: str):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(
device
)
@staticmethod
def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.self_attn_layer_norm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat["self_attn.out_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat["fc1"],
)
)
# linear 2
layers.append(
dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat["fc2"],
)
)
return layers
from .base import BaseAWQForCausalLM
class QwenAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "QWenBlock"
max_seq_len_key = "seq_length"
@staticmethod
def get_model_layers(model):
return model.transformer.h
@staticmethod
def get_act_for_scaling(module):
return dict(is_scalable=False)
@staticmethod
def move_embed(model, device: str):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.rotary_emb = model.transformer.rotary_emb.to(device)
@staticmethod
def get_layers_for_scaling(module, input_feat, module_kwargs):
layers = []
# attention
layers.append(
dict(
prev_op=module.ln_1,
layers=[module.attn.c_attn],
inp=input_feat["attn.c_attn"],
module2inspect=module.attn,
kwargs=module_kwargs,
)
)
# mlp
layers.append(
dict(
prev_op=module.ln_2,
layers=[module.mlp.w2, module.mlp.w1],
inp=input_feat["mlp.w2"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.w1,
layers=[module.mlp.c_proj],
inp=input_feat["mlp.c_proj"],
)
)
return layers
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer as OldQwen2DecoderLayer,
Qwen2ForCausalLM as OldQwen2ForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
layer_type = "Qwen2DecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldQwen2ForCausalLM):
fuser = Qwen2Fuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldQwen2ForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldQwen2DecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldQwen2ForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldQwen2DecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class Qwen2Fuser:
def __init__(self, model: OldQwen2ForCausalLM):
self.model = model
self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldQwen2DecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM
from transformers.models.stablelm.modeling_stablelm import (
StableLmDecoderLayer as OldStableLmDecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class StableLmAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "StableLmDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldStableLmForCausalLM):
fuser = StableLmFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldStableLmForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldStableLmDecoderLayer):
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldStableLmForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(
module: OldStableLmDecoderLayer, input_feat, module_kwargs
):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class StableLmFuser:
def __init__(self, model: OldStableLmForCausalLM):
self.model = model
self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldStableLmDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = module.input_layernorm
norm_2 = module.post_attention_layernorm
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
partial_rotary_factor=self.model.config.partial_rotary_factor,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.starcoder2.modeling_starcoder2 import (
Starcoder2ForCausalLM as OldStarcoder2ForCausalLM,
Starcoder2DecoderLayer as OldStarcoder2DecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class Starcoder2AWQForCausalLM(BaseAWQForCausalLM):
layer_type = "Starcoder2DecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldStarcoder2ForCausalLM):
fuser = Starcoder2Fuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldStarcoder2ForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldStarcoder2DecoderLayer):
return dict(
is_scalable=True,
scale_name="mlp.act",
scale_layer=module.mlp.act,
scale_shape=module.mlp.c_fc.out_features,
)
# return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldStarcoder2ForCausalLM, device):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldStarcoder2DecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.c_fc],
inp=input_feat["mlp.c_fc"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.act,
layers=[module.mlp.c_proj],
inp=input_feat["mlp.c_proj"],
)
)
return layers
class Starcoder2Fuser:
def __init__(self, model: OldStarcoder2ForCausalLM):
self.model = model
self.starcoder2_blocks: List[Tuple[str, OldStarcoder2DecoderLayer]] = [
(name, module)
for name, module in self.model.named_modules()
if "Starcoder2DecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldStarcoder2DecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
# Starcoder2 uses standard LayerNorm (not RMSNorm), so the existing norm modules are reused as-is
norm_1 = module.input_layernorm
norm_2 = module.post_attention_layernorm
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm
class YiAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "YiDecoderLayer"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model):
fuser = YiFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(is_scalable=False)
@staticmethod
def move_embed(model, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module, input_feat, module_kwargs):
layers = []
# attention input
layers.append(
dict(
prev_op=module.ln1,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(
dict(
prev_op=module.ln2,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
class YiFuser:
def __init__(self, model):
self.model = model
self.yi_blocks: List[Tuple[str, object]] = [
(name, module)
for name, module in self.model.named_modules()
if "YiDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.ln1.weight, module.ln1.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.ln2.weight, module.ln2.variance_epsilon
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import torch.nn as nn
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
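# Illustrative check (not part of the upstream diff): why dividing an activation's
# output by per-channel scales can be lossless. If the linear layer that consumes the
# activation has its weight columns multiplied by the same scales (which the AWQ
# scaling pass is assumed to do for the "mlp.act" / "ffn.act" entries returned by
# get_act_for_scaling above), the end-to-end function is unchanged.
def _check_act_scaling(hidden: int = 16, out: int = 8):
    import torch

    a = torch.relu(torch.randn(2, hidden))  # activation output
    W = torch.randn(out, hidden)            # weight of the following linear layer
    s = torch.rand(hidden) + 0.5            # positive per-channel scales

    ref = a @ W.t()                         # original computation
    scaled = (a / s) @ (W * s).t()          # ScaledActivation + column-scaled weight
    assert torch.allclose(ref, scaled, atol=1e-5)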
import os
import math
import torch
import torch.nn as nn
from torch.nn import functional as F
from awq.modules.fused.cache import WindowedCache
from awq.utils.fused_utils import get_attention_shapes
try:
import awq_ft_ext
FT_INSTALLED = True
except:
FT_INSTALLED = False
HF_NEW_CACHE_FORMAT = False
import transformers
# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
from transformers.cache_utils import DynamicCache
class RoPE(nn.Module):
def __init__(self, head_dim, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()
self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
requires_grad=False,
)
@staticmethod
def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end)
freqs = torch.outer(t, freqs).float()
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
return freqs_cis
@staticmethod
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
ndim = x.ndim
assert 0 <= 1 < ndim
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
return freqs_cis.view(*shape)
def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
xq_ = torch.view_as_complex(
xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
xk_ = torch.view_as_complex(
xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
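# Illustrative check (not part of the upstream diff): the precomputed table holds
# unit-magnitude complex rotations of shape (end, dim // 2), so applying RoPE only
# rotates each (even, odd) feature pair and never changes query/key norms.
def _check_rope_table(dim: int = 8, end: int = 16):
    import torch

    freqs_cis = RoPE.precompute_freqs_cis(dim, end)
    assert freqs_cis.shape == (end, dim // 2)
    # Every entry lies on the unit circle.
    assert torch.allclose(freqs_cis.abs(), torch.ones_like(freqs_cis.abs()))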
class ALiBi(nn.Module):
def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
super(ALiBi, self).__init__()
# Initialize ALiBi slopes and bias
slopes, bias = self.build_alibi_bias(
n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
)
self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
@staticmethod
def gen_slopes(n_heads, alibi_bias_max=8):
_n_heads = 2 ** math.ceil(math.log2(n_heads))
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
@staticmethod
def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, 1, seq_len
)
slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
def forward(self, scores, seqlen):
scores += self.bias[..., :seqlen]
return scores
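# Illustrative check (not part of the upstream diff): the ALiBi slopes form a
# geometric sequence per head, and the additive bias is a linear penalty on how far
# back a key position is. With n_heads=4 and alibi_bias_max=8 the slopes come out to
# 2^-2, 2^-4, 2^-6, 2^-8.
def _check_alibi_slopes():
    import torch

    slopes = ALiBi.gen_slopes(4, alibi_bias_max=8).flatten()
    assert torch.allclose(slopes, torch.tensor([2.0**-2, 2.0**-4, 2.0**-6, 2.0**-8]))

    _, bias = ALiBi.build_alibi_bias(4, seq_len=5)
    assert bias.shape == (1, 4, 1, 5)
    # The current position gets zero bias; earlier positions are increasingly penalized.
    assert torch.all(bias[..., -1] == 0)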
class QuantAttentionFused(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
dev,
max_seq_len=2048,
use_alibi=False,
attention_shapes=None,
rope_theta=10000,
partial_rotary_factor=1.0,
head_dim=None,
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
self.head_dim = head_dim
if head_dim is None:
self.head_dim = hidden_size // n_heads
self.qkv_proj = qkv_layer
self.o_proj = o_proj
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
if kwargs.get("max_new_tokens") is not None:
max_seq_len = kwargs["max_new_tokens"]
self.max_seq_len = max_seq_len
self.is_hf_transformers = False
self.rope_theta = rope_theta
# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
attention_shapes,
max_seq_len,
self.cache_batch_size,
n_heads,
n_kv_heads,
self.head_dim,
)
# rolling-window KV cache store
self.cache = WindowedCache(
self.attention_shapes["cache_v"],
self.attention_shapes["cache_k"],
self.max_seq_len,
dev,
)
if use_alibi:
self.alibi = ALiBi(n_heads, max_seq_len, dev)
self.rotary_dim = 0
self.is_neox = False
else:
self.alibi = None
self.partial_rotary_factor = partial_rotary_factor
self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
self.is_neox = True
def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
):
bsz, seqlen, _ = hidden_states.shape
# Reallocate cache if batch size changes
if bsz != self.cache_batch_size:
if bsz > self.cache_batch_size:
self.cache.increase_batch_size(bsz)
self.cache_batch_size = bsz
elif bsz < self.cache_batch_size:
self.cache.decrease_batch_size(bsz)
self.cache_batch_size = bsz
# Always reset to 0
self.start_pos = 0
hf_is_generating = False
hf_is_first_forward = "past_key_value" in kwargs and kwargs["past_key_value"] is None
hf_is_new_cache_first_forward = "past_key_value" in kwargs and isinstance(kwargs["past_key_value"], DynamicCache) and kwargs["past_key_value"].get_seq_length() == 0
if self.is_hf_transformers and "use_cache" in kwargs:
hf_is_generating = kwargs["use_cache"]
# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# This is only applicable for `transformers` integration
if (self.is_hf_transformers and (hf_is_first_forward or hf_is_new_cache_first_forward)) or (self.is_hf_transformers and not hf_is_generating):
self.start_pos = 0
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
xq = self.attention_shapes["xq_slice"](xqkv)
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)
if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
if not self.use_alibi:
# Partial rotary embedding
if self.partial_rotary_factor < 1:
xq_rot, xq_pass = (
xq[..., : self.rotary_dim],
xq[..., self.rotary_dim :],
)
xk_rot, xk_pass = (
xk[..., : self.rotary_dim],
xk[..., self.rotary_dim :],
)
xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
xq = torch.cat((xq_rot, xq_pass), dim=-1)
xk = torch.cat((xk_rot, xk_pass), dim=-1)
else:
xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
values_store = xv.transpose(2, 1)
keys_store = (
xk.reshape((bsz, seqlen) + self.attention_shapes["xk_reshape"])
.permute(0, 2, 3, 1, 4)
.contiguous()
)
self.cache.to(xq)
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
# Only necessary to retrieve from cache when we are not processing context
if seqlen == 1:
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
keys = xk
values = xv
if self.n_kv_groups != 0:
keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(
values, dim=2, repeats=self.n_kv_groups
)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
if self.use_alibi:
scores = self.alibi.forward(scores, seqlen)
# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
# For llama-arch, the causal mask is preallocated with bsz x 1 x max_seq_len x max_seq_len, thus we
# need to slice it
if attention_mask.shape[-1] != seqlen:
attention_mask = attention_mask[:, :, :seqlen, :seqlen]
scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
else:
xq = xq.view((bsz,) + self.attention_shapes["single_xq_view"])
xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = awq_ft_ext.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
self.rope_theta, # rotary embedding base
self.is_neox, # is neox
)
attention_weight = attention_weight.reshape(bsz, 1, -1)
attn_output = self.o_proj(attention_weight)
self.start_pos += seqlen
if self.is_hf_transformers and not hf_is_generating:
self.start_pos = 0
# The real KV state lives in this module's cache_k / cache_v buffers, so we return a
# dummy past_key_value whose only purpose is to let transformers recover the correct
# past key length.
past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
if HF_NEW_CACHE_FORMAT and self.is_hf_transformers:
new_cache = DynamicCache()
new_cache.update(past_key_value[0], past_key_value[0], layer_idx=0)
past_key_value = new_cache
return attn_output, attention_weight, past_key_value
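# Summary note (added for clarity, not part of the upstream diff): the forward pass
# above has two paths. Prompt/context processing (seqlen > 1), a partial rotary
# factor, or a missing awq_ft_ext build all take the pure-PyTorch branch, which
# applies RoPE or ALiBi explicitly and reads/writes the WindowedCache by hand.
# Single-token decode with awq_ft_ext installed takes the fused branch, where
# single_query_attention is expected to handle the rotary embedding and the KV-cache
# update inside the kernel.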
import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MixtralBlock(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
moe,
norm_1,
norm_2,
dev,
max_seq_len,
rope_theta,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=False,
rope_theta=rope_theta,
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev
def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
)
h = hidden_states.to(attn_output.device) + attn_output
out = self.moe.forward(self.norm_2(h))
out = h + out
return out, None, past_key_value
class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
an architecture that closely resembles Llama, e.g. Mistral and Aquila.
"""
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
mlp,
norm_1,
norm_2,
dev,
max_seq_len,
rope_theta=10000,
partial_rotary_factor=1.0,
use_alibi=False,
head_dim=None,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.head_dim = hidden_size // n_heads
# Gemma-7B uses a head_dim that differs from hidden_size // n_heads, so allow an explicit override
if head_dim:
self.head_dim = head_dim
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=use_alibi,
rope_theta=rope_theta,
partial_rotary_factor=partial_rotary_factor,
head_dim=head_dim,
).to(dev)
self.norm_2 = norm_2.to(dev)
self.mlp = mlp.to(dev)
self.device = dev
def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
)
h = hidden_states.to(attn_output.device) + attn_output
out = h + self.mlp.forward(self.norm_2(h))
return out, None, past_key_value
class MPTBlock(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
qkv_layer,
o_proj,
mpt_mlp,
norm_1,
norm_2,
dev,
max_seq_len,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 0
self.hidden_size = hidden_size
self.norm_1 = norm_1
self.attn = QuantAttentionFused(
hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=True,
).to(dev)
self.norm_2 = norm_2
self.ffn = mpt_mlp.to(dev)
self.device = dev
def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True,
)
h = hidden_states.to(attn_output.device) + attn_output
out = h + self.ffn.forward(self.norm_2(h))
return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
def __init__(
self,
hidden_size,
n_heads,
qkv_layer,
o_proj,
mlp,
dev,
max_seq_len,
input_layernorm=None,
ln_attn=None,
ln_mlp=None,
new_decoder_arch=True,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 8 if new_decoder_arch else 0
self.hidden_size = hidden_size
self.new_decoder_arch = new_decoder_arch
if new_decoder_arch:
attention_shapes = None
else:
attention_shapes = self._get_attention_shapes(
n_heads, max_seq_len, self.hidden_size // n_heads
)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=False,
attention_shapes=attention_shapes,
).to(dev)
if new_decoder_arch:
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
else:
self.input_layernorm = input_layernorm # before attention
self.mlp = mlp
self.device = dev
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (
batch_size,
1,
max_seq_len,
head_dim,
),
# FasterTransformer packs 8 fp16 values per key-cache slot (use 4 for fp32)
"cache_k": (
batch_size,
1,
head_dim // 8,
max_seq_len,
8,
),
"xqkv_view": (n_heads + 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
"xq_view": (n_heads, head_dim),
"xk_view": (1, head_dim),
"xv_view": (1, head_dim),
"xk_reshape": (1, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim),
}
return self.attention_shapes
def forward(
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
if self.new_decoder_arch:
layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
layernorm_out = self.input_layernorm(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=layernorm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True,
)
h_attn = hidden_states.to(attn_output.device) + attn_output
if self.new_decoder_arch:
h_mlp = self.mlp.forward(mlp_layernorm_out)
else:
h_mlp = self.mlp.forward(layernorm_out)
out = h_attn + h_mlp
return out, None, past_key_value
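# Illustrative sketch (not part of the upstream diff) of the multi-query layout that
# _get_attention_shapes above encodes for the older Falcon architecture: the fused
# QKV projection yields n_heads query heads plus exactly one shared key head and one
# shared value head, so xqkv is viewed as (n_heads + 2, head_dim) and sliced with
# [:-2], [-2], and [-1].
def _check_falcon_mqa_slicing(bsz: int = 1, seqlen: int = 3, n_heads: int = 4, head_dim: int = 8):
    import torch

    xqkv = torch.randn(bsz, seqlen, (n_heads + 2) * head_dim)
    xqkv = xqkv.view(bsz, seqlen, n_heads + 2, head_dim)

    xq = xqkv[:, :, :-2]   # (bsz, seqlen, n_heads, head_dim) query heads
    xk = xqkv[:, :, [-2]]  # (bsz, seqlen, 1, head_dim) shared key head
    xv = xqkv[:, :, [-1]]  # (bsz, seqlen, 1, head_dim) shared value head

    assert xq.shape == (bsz, seqlen, n_heads, head_dim)
    assert xk.shape == xv.shape == (bsz, seqlen, 1, head_dim)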