Commit b49168f8 authored by Casper Hansen

Add bigcode fused modules (WIP)

parent dff3d157
from .base import BaseAWQForCausalLM
from transformers.models.gpt_bigcode.modeling_gpt_bigcode import GPTBigCodeForCausalLM, GPTBigCodeBlock as OldGptBigCodeBlock

class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GPTBigCodeBlock"
    max_new_tokens_key = "n_positions"

    @staticmethod
    def fuse_layers(model: GPTBigCodeForCausalLM, quant_config: dict):
        # TODO: Fix single_query_attention
        pass
        # fuser = GptBigCodeFuser(model)
        # fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: GPTBigCodeForCausalLM):
        return model.transformer.h

    @staticmethod
    def get_act_for_scaling(module: OldGptBigCodeBlock):
        return dict(
            is_scalable=True,
            scale_name="mlp.act",
@@ -24,7 +32,7 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
        model.transformer.drop = model.transformer.drop.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldGptBigCodeBlock, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -52,3 +60,41 @@ class GptBigCodeAWQForCausalLM(BaseAWQForCausalLM):
        ))

        return layers
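The hunk above elides the individual scaling groups. For orientation only, AutoAWQ's get_layers_for_scaling implementations append dicts of the shape sketched below; the concrete prev_op/layers/inp values shown are illustrative stand-ins, not the elided GPT-BigCode entries:

        # Hedged sketch of a scaling-group entry, assuming AutoAWQ's usual keys;
        # the module attributes named here are illustrative, not the elided code.
        layers.append(dict(
            prev_op=module.ln_1,            # op whose output feeds the scaled layers (illustrative)
            layers=[module.attn.c_attn],    # linear layers that receive the scale (illustrative)
            inp=input_feat['attn.c_attn'],  # calibration activations for those layers (illustrative)
            module2inspect=module.attn,     # module re-run to measure the effect of the scale (illustrative)
            kwargs=module_kwargs,
        ))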
from typing import List, Tuple
from awq.modules.fused.block import GptBigCodeBlock
from awq.modules.fused.model import GptBigCodeModel

class GptBigCodeFuser:
    def __init__(self, model: GPTBigCodeForCausalLM):
        self.model = model

        self.blocks: List[Tuple[str, OldGptBigCodeBlock]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, OldGptBigCodeBlock)
        ]
    def fuse_transformer(self):
        blocks = []

        module: OldGptBigCodeBlock
        for module in self.model.transformer.h:
            blocks.append(GptBigCodeBlock(
                self.model.config.n_embd,
                self.model.config.n_head,
                module.attn.c_attn,
                module.attn.c_proj,
                module.mlp,
                module.ln_1,
                module.ln_2,
                next(iter(module.state_dict().values())).device,
                self.model.config.n_positions
            ))

        self.model.transformer = GptBigCodeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.transformer.wte,
            self.model.transformer.wpe,
            self.model.transformer.ln_f,
        )
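Once the single_query_attention TODO is resolved, fuse_layers above is presumably meant to just delegate to this fuser. A minimal sketch mirroring the commented-out lines in fuse_layers:

    @staticmethod
    def fuse_layers(model: GPTBigCodeForCausalLM, quant_config: dict):
        # Wrap the HF model and swap model.transformer for the fused GptBigCodeModel.
        fuser = GptBigCodeFuser(model)
        fuser.fuse_transformer()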
@@ -113,3 +113,58 @@ class FalconDecoderLayer(nn.Module):
        out = h_attn + h_mlp

        return out, None, past_key_value
class GptBigCodeBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, norm_1, norm_2, dev, max_seq_len):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1

        attention_shapes = self._get_attention_shapes(
            max_seq_len, self.hidden_size // n_heads
        )

        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
            dev=dev, max_seq_len=max_seq_len, use_alibi=False, attention_shapes=attention_shapes
        ).to(dev)

        self.norm_2 = norm_2
        self.ffn = mlp.to(dev)

    def _get_attention_shapes(self, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        return {
            # following fastertransformer definition
            "cache_v": (batch_size, self.n_heads, max_seq_len, head_dim,),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (batch_size, self.n_heads, head_dim // 8, max_seq_len, 8,),
            "xqkv_view": (-1, self.n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :, 0],
            "xk_slice": lambda xqkv: xqkv[:, :, :, 1],
            "xv_slice": lambda xqkv: xqkv[:, :, :, 2],
            "xq_view": (1, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (1, head_dim),
            "single_xk_view": (self.n_heads, head_dim),
            "single_xv_view": (self.n_heads, head_dim)
        }
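For orientation on the multi-query shapes above: GPT-BigCode's c_attn emits n_heads query heads plus one shared key head and one shared value head, i.e. (n_heads + 2) * head_dim features per token, which is what the xqkv_view reshape relies on. A minimal, standalone shape check with illustrative SantaCoder-like numbers (n_heads=16, head_dim=128); the xq/xk/xv split shown here is just one way to slice that view, and reconciling it with the slices above is exactly what the single_query_attention TODO still has to do:

import torch

n_heads, head_dim, seqlen = 16, 128, 10                # illustrative sizes, not from the commit
xqkv = torch.randn(seqlen, (n_heads + 2) * head_dim)   # stand-in for the c_attn output of one sequence
xqkv = xqkv.view(-1, n_heads + 2, head_dim)            # matches "xqkv_view": (-1, n_heads + 2, head_dim)
xq = xqkv[:, :n_heads]                                 # per-head queries    -> (10, 16, 128)
xk = xqkv[:, [n_heads]]                                # single shared key   -> (10, 1, 128)
xv = xqkv[:, [n_heads + 1]]                            # single shared value -> (10, 1, 128)
print(xq.shape, xk.shape, xv.shape)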
    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))

        return out, None, past_key_value
\ No newline at end of file
import torch
import torch.nn as nn
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, GptBigCodeBlock
from transformers.modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPastAndCrossAttentions

class MPTModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, norm_f):
@@ -63,3 +63,57 @@ class FalconModel(nn.Module):
        h = self.ln_f(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
class GptBigCodeModel(nn.Module):
    def __init__(self, vocab_size, blocks, wte, wpe, ln_f):
        super().__init__()
        self.vocab_size = vocab_size
        self.wte = wte
        self.wpe = wpe
        self.blocks: list[GptBigCodeBlock] = nn.ModuleList(blocks)
        self.ln_f = ln_f
        self.attn_uses_sequence_id = False
        self.prefix_lm = False

    @torch.inference_mode()
    def forward(self, input_ids, position_ids=None, token_type_ids=None, attention_mask=None, past_key_values=None, is_causal=None, *args, **kwargs):
        _bsz, seqlen = input_ids.shape

        if past_key_values is None:
            past_length = 0
            past_key_values = tuple([None] * len(self.blocks))
        else:
            past_length = past_key_values[0].size(-2)

        if attention_mask is not None and len(attention_mask.shape) == 2 and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_length > 0:
                position_ids = position_ids[:, past_length : input_ids.size()[-1] + past_length]
        elif position_ids is None:
            position_ids = torch.arange(past_length, input_ids.size()[-1] + past_length, dtype=torch.long, device=input_ids.device)
            position_ids = position_ids.unsqueeze(0).view(-1, input_ids.size()[-1])

        input_embeds = self.wte(input_ids)
        position_embeds = self.wpe(position_ids)
        h = input_embeds + position_embeds

        if token_type_ids is not None:
            token_type_embeds = self.wte(token_type_ids)
            h = h + token_type_embeds

        mask = None
        if seqlen > 1:
            mask = torch.full(
                (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
            )
            mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)

        for layer in self.blocks:
            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)

        h = self.ln_f(h)

        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=(), cross_attentions=()
        )
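A hedged usage sketch (not part of the commit) of how the fused transformer would be driven once fuse_layers is enabled: the fuser swaps model.transformer for this GptBigCodeModel while the HF GPTBigCodeForCausalLM keeps its lm_head, so a single greedy decode step could look roughly like the following, with model standing in for an already quantized and fused model:

import torch

# `model` is assumed to be a GPTBigCodeForCausalLM whose .transformer has been
# replaced by the fused GptBigCodeModel via GptBigCodeFuser(model).fuse_transformer().
input_ids = torch.randint(0, 100, (1, 8), device="cuda")  # dummy prompt tokens (illustrative)
out = model.transformer(input_ids)                         # fused forward defined above
logits = model.lm_head(out.last_hidden_state)              # lm_head is untouched by the fuser
next_token = logits[:, -1].argmax(dim=-1)                  # greedy pick of the next token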