import os

import torch.nn as nn

from awq.modules.fused.attn import QuantAttentionFused


class MixtralBlock(nn.Module):
    """Fused transformer block for Mixtral-style models with a sparse-MoE feed-forward (``moe``)."""

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        # Residual connection around attention, then around the sparse-MoE feed-forward.
        h = hidden_states.to(attn_output.device) + attn_output
        out = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value


class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        use_alibi=False,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        # Residual connection around attention, then around the MLP.
        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value


class MPTBlock(nn.Module):
    """Fused transformer block for MPT-style models, which use ALiBi instead of rotary embeddings."""

    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mpt_mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        # Residual connection around attention, then around the feed-forward.
        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))

        return out, None, past_key_value


class FalconDecoderLayer(nn.Module):
    """Fused decoder layer for Falcon. When ``new_decoder_arch`` is True, separate
    layernorms (``ln_attn``/``ln_mlp``) are used for the attention and MLP branches."""

    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mlp,
        dev,
        max_seq_len,
        input_layernorm=None,
        ln_attn=None,
        ln_mlp=None,
        new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
        self.hidden_size = hidden_size
        self.new_decoder_arch = new_decoder_arch

        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(
                n_heads, max_seq_len, self.hidden_size // n_heads
            )

        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            attention_shapes=attention_shapes,
        ).to(dev)

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp
        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        self.attention_shapes = {
            # following fastertransformer definition
            "cache_v": (
                batch_size,
                1,
                max_seq_len,
                head_dim,
            ),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (
                batch_size,
                1,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
            "xq_view": (n_heads, head_dim),
            "xk_view": (1, head_dim),
            "xv_view": (1, head_dim),
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim),
        }

        return self.attention_shapes

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output

        # Attention and MLP branches are computed in parallel and summed with the residual.
        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value
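
# Usage sketch (illustrative only, not part of this module): shows how a model
# builder might wire a LlamaLikeBlock together. Names such as `quant_qkv`,
# `quant_o_proj`, `quant_mlp`, `input_layernorm`, and `post_attention_layernorm`
# are hypothetical placeholders for the already-quantized submodules and the
# source model's norms that the caller is assumed to provide.
#
#   block = LlamaLikeBlock(
#       hidden_size=4096,
#       n_heads=32,
#       n_kv_heads=8,
#       qkv_layer=quant_qkv,             # fused, quantized QKV projection
#       o_proj=quant_o_proj,             # quantized output projection
#       mlp=quant_mlp,                   # quantized feed-forward module
#       norm_1=input_layernorm,          # pre-attention norm from the source model
#       norm_2=post_attention_layernorm, # pre-MLP norm from the source model
#       dev="cuda:0",
#       max_seq_len=4096,
#       rope_theta=10000,
#   )
#   out, _, past_key_value = block(hidden_states, past_key_value)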