Unverified commit 8110e028, authored by Casper and committed by GitHub

Create fused LlamaLikeModel (#152)

parent 84a26861
awq/models/aquila.py
## Reference from llama.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldAquilaDecoderLayer,
    LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class AquilaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "AquilaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldAquilaForCausalLM):
        fuser = AquilaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldAquilaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldAquilaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldAquilaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldAquilaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -72,85 +73,57 @@ class AquilaAWQForCausalLM(BaseAWQForCausalLM):
        return layers

class AquilaFuser:
    def __init__(self, model: OldAquilaForCausalLM):
        self.model = model

        self.aquila_blocks: List[Tuple[str, OldAquilaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'AquilaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldAquilaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
awq/models/llama.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldLlamaForCausalLM):
        fuser = LlamaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldLlamaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldLlamaDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldLlamaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -65,86 +73,57 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
        return layers

class LlamaFuser:
    def __init__(self, model: OldLlamaForCausalLM):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldLlamaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
awq/models/mistral.py
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.mlp import QuantLlamaMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MistralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MistralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMistralForCausalLM):
        fuser = MistralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMistralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldMistralDecoderLayer):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldMistralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -65,86 +73,57 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
        return layers

class MistralFuser:
    def __init__(self, model: OldMistralForCausalLM):
        self.model = model

        self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMistralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantLlamaMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
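
For context, these fusers are not called directly: fuse_layers runs when a quantized checkpoint is loaded with fuse_layers=True, as the example at the end of this diff shows. A minimal sketch of that entry point (the checkpoint path is just a placeholder):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Placeholder checkpoint; any AWQ-quantized Llama/Mistral/Aquila model follows the same path.
quant_path = "TheBloke/zephyr-7B-beta-AWQ"

# fuse_layers=True makes the loader call the matching *AWQForCausalLM.fuse_layers(),
# which swaps model.model for the fused LlamaLikeModel built by fuse_transformer().
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)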
awq/modules/fused/attn.py
@@ -123,17 +123,6 @@ class QuantAttentionFused(nn.Module):
    def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
        bsz, seqlen, _ = hidden_states.shape

        if bsz != self.cache_batch_size:
            raise RuntimeError(
                f"Batch size is incorrectly set - input batch size {bsz}, kv-cache batch size {self.cache_batch_size}. "
......
awq/modules/fused/block.py
@@ -2,6 +2,39 @@ import os
import torch.nn as nn

from awq.modules.fused.attn import QuantAttentionFused

class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask
        )

        h = hidden_states + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value

class MPTBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
......
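
LlamaLikeBlock.forward above is the standard pre-norm residual layout: norm, attention, residual add, then norm, MLP, residual add. A self-contained toy sketch of that data flow, with plain nn.Linear stand-ins for the quantized attention and MLP (the toy module is illustrative only, not part of AutoAWQ):

import torch
import torch.nn as nn

class ToyPreNormBlock(nn.Module):
    # Same residual structure as LlamaLikeBlock.forward, with ordinary layers
    # standing in for QuantAttentionFused and QuantLlamaMLP.
    def __init__(self, hidden_size):
        super().__init__()
        self.norm_1 = nn.LayerNorm(hidden_size)
        self.attn = nn.Linear(hidden_size, hidden_size)  # stand-in for fused attention
        self.norm_2 = nn.LayerNorm(hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)   # stand-in for the gated MLP

    def forward(self, hidden_states):
        h = hidden_states + self.attn(self.norm_1(hidden_states))  # attention residual
        return h + self.mlp(self.norm_2(h))                        # MLP residual

x = torch.randn(1, 8, 32)
print(ToyPreNormBlock(32)(x).shape)  # torch.Size([1, 8, 32])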
awq/modules/fused/model.py
import torch
import torch.nn as nn
from typing import List
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock

class LlamaLikeModel(nn.Module):
    """
    LlamaLikeModel is intended to be reused across models that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """
    def __init__(self, vocab_size, blocks, embedding, norm):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = blocks
        self.norm = norm
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.embedding(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
        h = self.norm(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())

class MPTModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, norm_f):
@@ -13,18 +50,24 @@ class MPTModel(nn.Module):
        self.norm_f = norm_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.wte(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
@@ -41,23 +84,24 @@ class FalconModel(nn.Module):
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False
        self.last_forward_num_tokens = 0

    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
        input_ids, self.last_forward_num_tokens = prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

        h = self.word_embeddings(input_ids)

        mask = prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
            type_as=h
        )

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
......
awq/utils/fused_utils.py
import torch

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
    # NOTE: from transformers 4.35.0, input_ids includes full context during decoding
    num_input_tokens = input_ids.shape[-1]
    num_new_tokens = num_input_tokens

    if num_input_tokens != 1:
        num_new_tokens = num_input_tokens - last_forward_num_tokens

        # after context is processed, slice to latest token
        if num_new_tokens in [0, 1]:
            input_ids = input_ids[:, -1:]

    return input_ids, last_forward_num_tokens + num_new_tokens
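
A quick illustration of the bookkeeping above: under the transformers >= 4.35.0 behaviour, the full context arrives on every decode step, and prepare_input_ids keeps only the newest token once the prompt has been processed (token values below are arbitrary):

import torch
from awq.utils.fused_utils import prepare_input_ids

ids = torch.tensor([[11, 12, 13]])        # prompt step: 3 tokens
ids, seen = prepare_input_ids(ids, 0)     # full prompt passes through
print(ids.tolist(), seen)                 # [[11, 12, 13]] 3

ids = torch.tensor([[11, 12, 13, 14]])    # decode step: full context arrives again
ids, seen = prepare_input_ids(ids, seen)  # only the new token survives
print(ids.tolist(), seen)                 # [[14]] 4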
def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
    mask = None
    if seqlen > 1:
        mask = torch.full(
            (1, 1, seqlen, seqlen), float("-inf"), device=device
        )
        mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)

    return mask
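
The mask helper behaves accordingly: while the prompt is being processed (seqlen > 1) it builds a causal mask offset by the current cache position, and during single-token decoding it returns None:

import torch
from awq.utils.fused_utils import prepare_attention_mask

h = torch.zeros(1, dtype=torch.float16)  # only used to pick the mask dtype via type_as
mask = prepare_attention_mask(seqlen=4, start_pos=0, device="cpu", type_as=h)
print(mask.shape)  # torch.Size([1, 1, 4, 4]); entries above the diagonal are -inf
print(prepare_attention_mask(seqlen=1, start_pos=4, device="cpu", type_as=h))  # None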
def fuse_qkv(module, q_proj, k_proj, v_proj):
    bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None

    if isinstance(q_proj, WQLinear_GEMV):
        q_linear = WQLinear_GEMV
    else:
        q_linear = WQLinear_GEMM

    qkv_layer = q_linear(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        q_proj.bias is not None,
        next(iter(module.state_dict().values())).device
    )

    if isinstance(qkv_layer, WQLinear_GEMV):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.split_k_iters = q_proj.split_k_iters
    else:
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)

    qkv_layer.bias = bias

    return qkv_layer
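
The only difference between the two branches in fuse_qkv is the axis the quantized tensors are stacked on: GEMV layers keep output features on dim 0, GEMM layers on dim 1. A toy illustration with made-up shapes (not the real packed layouts):

import torch

# GEMM-style: output features on dim 1, so q/k/v are concatenated column-wise.
q, k, v = (torch.zeros(16, 8) for _ in range(3))
print(torch.cat([q, k, v], dim=1).shape)  # torch.Size([16, 24])

# GEMV-style: output features on dim 0, so q/k/v are stacked row-wise.
q, k, v = (torch.zeros(8, 16) for _ in range(3))
print(torch.cat([q, k, v], dim=0).shape)  # torch.Size([24, 16])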
def get_attention_shapes(attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim):
    if attention_shapes is not None:
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/zephyr-7B-beta-AWQ"

# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
@@ -10,11 +10,11 @@ streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
<|system|>
</s>
<|user|>
{prompt}</s>
<|assistant|>"""

prompt = "You're standing on the surface of the Earth. "\
        "You walk one mile south, one mile west and one mile north. "\
......
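
The rest of the example is elided in the diff; a typical continuation (an illustrative sketch, with max_new_tokens chosen arbitrarily) tokenizes the formatted prompt and streams generation through the fused model:

tokens = tokenizer(
    prompt_template.format(prompt=prompt),
    return_tensors="pt"
).input_ids.cuda()

# Streams tokens as they are produced by the fused LlamaLikeModel.
generation_output = model.generate(
    tokens,
    streamer=streamer,
    max_new_tokens=512
)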