Unverified commit 8110e028 authored by Casper, committed by GitHub

Create fused LlamaLikeModel (#152)

parent 84a26861
## Reference from llama.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldAquilaDecoderLayer,
    LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class AquilaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "AquilaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldAquilaForCausalLM):
        fuser = AquilaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldAquilaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldAquilaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldAquilaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -72,85 +73,57 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
        return layers
class AquilaFuser:
    def __init__(self, model: OldAquilaForCausalLM):
        self.model = model

        self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldAquilaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
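With the Q, K and V projections fused into a single linear layer, that layer's output width is simply the sum of the three projections' output features. A small worked example of the arithmetic, using hypothetical Llama-style dimensions that are not taken from this diff:

# Hypothetical dimensions for illustration only; not read from any checkpoint in this diff.
hidden_size = 4096
n_heads = 32        # config.num_attention_heads
n_kv_heads = 8      # config.num_key_value_heads (smaller than n_heads under grouped-query attention)
head_dim = hidden_size // n_heads   # 128

q_out = n_heads * head_dim      # q_proj.out_features -> 4096
k_out = n_kv_heads * head_dim   # k_proj.out_features -> 1024
v_out = n_kv_heads * head_dim   # v_proj.out_features -> 1024

# fuse_qkv builds one quantized linear of exactly this width:
# q_proj.out_features + k_proj.out_features + v_proj.out_features
fused_out_features = q_out + k_out + v_out
print(fused_out_features)  # 6144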
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldLlamaForCausalLM):
        fuser = LlamaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldLlamaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldLlamaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldLlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -65,86 +73,57 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
        return layers
class LlamaFuser:
    def __init__(self, model: OldLlamaForCausalLM):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldLlamaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
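With the fuser wired into fuse_layers, the fused path is reached through the normal loading call, exactly as the example script at the bottom of this diff does. A minimal sketch, assuming an AWQ-quantized Llama-family checkpoint (the model id is just the one used in that example):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "TheBloke/zephyr-7B-beta-AWQ"

# fuse_layers=True ends up calling LlamaAWQForCausalLM.fuse_layers(),
# which runs LlamaFuser.fuse_transformer() and swaps in the fused LlamaLikeModel.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)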
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MistralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MistralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMistralForCausalLM):
        fuser = MistralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMistralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldMistralDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldMistralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -65,86 +73,57 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
        return layers
class MistralFuser:
    def __init__(self, model: OldMistralForCausalLM):
        self.model = model

        self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMistralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
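Each fuser passes `self.model.config.max_new_tokens` as `max_seq_len`, and the class attribute `max_new_tokens_key = "max_position_embeddings"` names the config field it falls back on; the resolution itself lives in `BaseAWQForCausalLM`, which is not part of this diff. A hedged sketch of the idea only (the helper name and exact fallback rule below are assumptions, not the library's code):

from typing import Optional
from transformers import PretrainedConfig

def resolve_max_seq_len(config: PretrainedConfig, max_new_tokens_key: str,
                        requested: Optional[int] = None) -> int:
    """Hypothetical helper: prefer an explicit request, otherwise fall back to the
    positional limit named by max_new_tokens_key (e.g. max_position_embeddings)."""
    if requested is not None:
        return requested
    return getattr(config, max_new_tokens_key)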
@@ -123,17 +123,6 @@ class QuantAttentionFused(nn.Module):
    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        bsz, seqlen, _ = hidden_states.shape
-
-       # Check if we are under the transformers caching regime
-       has_past_key_value = kwargs is not None and "past_key_value" in kwargs and kwargs["past_key_value"] is not None
-
-       if has_past_key_value:
-           # In the newest transformers versions, when caching is used the input hidden states are not just
-           # the last generated token but the whole sequence minus the past KV length. We need to slice to
-           # the last token and set `seqlen=1`.
-           if seqlen > 1:
-               seqlen = 1
-               hidden_states = hidden_states[:, -1:]

        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
...
@@ -2,6 +2,39 @@ import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value
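For reference, `LlamaLikeBlock.forward` is the standard pre-norm residual pattern around a fused attention and a gated MLP. A plain, unquantized PyTorch sketch of that dataflow (illustrative only; the attention stand-in below is not the fused kernel, and the dimensions are arbitrary):

import torch
import torch.nn as nn
import torch.nn.functional as F

class GatedMLPSketch(nn.Module):
    """Llama-style gated MLP, mirroring the gate/up/down projections handed to QuantLlamaMLP."""
    def __init__(self, hidden_size, intermediate_size):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

class PreNormBlockSketch(nn.Module):
    """x + attn(norm_1(x)), then h + mlp(norm_2(h)) -- the flow LlamaLikeBlock fuses."""
    def __init__(self, hidden_size, n_heads):
        super().__init__()
        self.norm_1 = nn.LayerNorm(hidden_size)  # stand-in for FasterTransformerRMSNorm
        self.attn = nn.MultiheadAttention(hidden_size, n_heads, batch_first=True)
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.mlp = GatedMLPSketch(hidden_size, 4 * hidden_size)

    def forward(self, x):
        a = self.norm_1(x)
        attn_out, _ = self.attn(a, a, a, need_weights=False)
        h = x + attn_out                     # residual around attention
        return h + self.mlp(self.norm_2(h))  # residual around the MLP

x = torch.randn(1, 8, 64)
print(PreNormBlockSketch(64, n_heads=4)(x).shape)  # torch.Size([1, 8, 64])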
class MPTBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
...
import torch
import torch.nn as nn
from typing import List
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = blocks
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.embedding(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return BaseModelOutputWithPast(
            last_hidden_state=h,
            past_key_values=past_key_value,
            hidden_states=(),
            attentions=()
        )
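The mask that `prepare_attention_mask` (added in the fused utils further below) hands to every block is the usual additive causal mask, and it is skipped entirely during single-token decoding (`seqlen == 1`). A small self-contained illustration with a hypothetical prefill of four tokens and an empty cache:

import torch

seqlen, start_pos = 4, 0  # hypothetical values: 4 prompt tokens, nothing cached yet
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"))
mask = torch.triu(mask, diagonal=start_pos + 1)
print(mask[0, 0])
# tensor([[0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf],
#         [0., 0., 0., 0.]])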
class MPTModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, norm_f):
@@ -13,18 +50,24 @@ class MPTModel(nn.Module):
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.wte(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
@@ -41,23 +84,24 @@ class FalconModel(nn.Module):
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.word_embeddings(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
...
import torch
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    # NOTE: from transformers 4.35.0, input_ids includes the full context during decoding
    num_input_tokens = input_ids.shape[-1]
    num_new_tokens = num_input_tokens

    if num_input_tokens != 1:
        num_new_tokens = num_input_tokens - last_forward_num_tokens

        # after the context is processed, slice to the latest token
        if num_new_tokens in [0, 1]:
            input_ids = input_ids[:, -1:]

    return input_ids, last_forward_num_tokens + num_new_tokens

def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
    mask = None
    if seqlen > 1:
        mask = torch.full(
            (1, 1, seqlen, seqlen), float("-inf"), device=device
        )
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)

    return mask

def fuse_qkv(module, q_proj, k_proj, v_proj):
    bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

    if isinstance(q_proj, WQLinear_GEMV):
        q_linear = WQLinear_GEMV
    else:
        q_linear = WQLinear_GEMM

    qkv_layer = q_linear(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        q_proj.bias is not None,
        next(iter(module.state_dict().values())).device
    )

    if isinstance(qkv_layer, WQLinear_GEMV):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.split_k_iters = q_proj.split_k_iters
    else:
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

    qkv_layer.bias = bias

    return qkv_layer
def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
    if attention_shapes is not None:
...
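`prepare_input_ids` above takes over the slicing that this change removes from `QuantAttentionFused.forward`: since transformers 4.35.0 the full running sequence is passed on every decode step, so the helper tracks how many tokens have already been processed and trims the input back to the newest token. A quick check of that behaviour (token values are arbitrary; assumes the package from this diff is installed):

import torch
from awq.utils.fused_utils import prepare_input_ids

# Prefill: 5 prompt tokens, nothing processed yet -> passed through untouched.
ids = torch.tensor([[11, 12, 13, 14, 15]])
ids, seen = prepare_input_ids(ids, last_forward_num_tokens=0)
print(ids.shape, seen)   # torch.Size([1, 5]) 5

# Decode step: transformers hands back the whole 6-token sequence,
# but only one token is new -> sliced down to the last token.
ids = torch.tensor([[11, 12, 13, 14, 15, 16]])
ids, seen = prepare_input_ids(ids, last_forward_num_tokens=seen)
print(ids.shape, seen)   # torch.Size([1, 1]) 6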
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/zephyr-7B-beta-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
@@ -10,11 +10,11 @@ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

# Convert prompt to tokens
prompt_template = """\
<|system|>
</s>
<|user|>
{prompt}</s>
<|assistant|>"""

prompt = "You're standing on the surface of the Earth. "\
    "You walk one mile south, one mile west and one mile north. "\
...
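The rest of the example script is collapsed in this diff; what follows is only a hedged sketch of how such a script typically finishes, not the elided lines themselves (the `max_new_tokens` value and the exact `generate` arguments are assumptions):

# Hypothetical continuation for illustration; not the collapsed content of the example file.
tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors="pt"
).input_ids.cuda()

# Stream the completion token by token through the TextStreamer configured above.
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512  # arbitrary choice for this sketch
)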