"vscode:/vscode.git/clone" did not exist on "0d2a151ec81344e81fd345f3e53edd65ff856d5b"
Unverified Commit 3b362c0d authored by Younes Belkada, committed by GitHub

[`core`] Replace `QuantLlamaMLP` with `QuantFusedMLP` (#188)

parent 09c73fb2
......
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldAquilaDecoderLayer,
LlamaForCausalLM as OldAquilaForCausalLM
)
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class AquilaAWQForCausalLM(BaseAWQForCausalLM):
......
@@ -95,7 +95,7 @@ class AquilaFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
-mlp = QuantLlamaMLP(
+mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
......
......
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM
)
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
......
@@ -95,7 +95,7 @@ class LlamaFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
-mlp = QuantLlamaMLP(
+mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
......
......
@@ -8,7 +8,7 @@ from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OldMistralDecoderLayer,
MistralForCausalLM as OldMistralForCausalLM
)
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MistralAWQForCausalLM(BaseAWQForCausalLM):
......
@@ -95,7 +95,7 @@ class MistralFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
-mlp = QuantLlamaMLP(
+mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
......
......
@@ -4,7 +4,7 @@ from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class YiAWQForCausalLM(BaseAWQForCausalLM):
......
@@ -90,7 +90,7 @@ class YiFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
-mlp = QuantLlamaMLP(
+mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
......
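The four model changes above (Aquila, Llama, Mistral, Yi) are the same one-line rename: every fuser hands its quantized gate/down/up projections to the fused MLP in the same way, which is why the class is no longer named after Llama. A hedged sketch of that shared pattern; the build_fused_mlp helper below is hypothetical and not part of this commit:

from awq.modules.fused.mlp import QuantFusedMLP

def build_fused_mlp(decoder_layer):
    # Each supported decoder layer exposes the same quantized WQLinear
    # projections on its MLP, so one fused module covers all four models.
    return QuantFusedMLP(
        decoder_layer.mlp.gate_proj,
        decoder_layer.mlp.down_proj,
        decoder_layer.mlp.up_proj,
    )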
......
@@ -131,7 +131,16 @@ class QuantAttentionFused(nn.Module):
elif bsz < self.cache_batch_size:
self.cache.decrease_batch_size(bsz)
self.cache_batch_size = bsz
-# Always reset to 0
-self.start_pos = 0
+# In case we re-generate, we need to refresh the starting position
+# to 0. We detect it by checking if `past_key_values` is set to None,
+# which indicates that we are on the first step of `generate()`.
+if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
+    self.start_pos = 0
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
......
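The attention change above makes the rewind of the rolling cache position conditional: start_pos is only reset when past_key_value arrives as None, which the diff's comment identifies as the first step of generate(); later decode steps keep advancing the cached position instead of always restarting at 0. A minimal sketch of that bookkeeping, with every name other than start_pos and the kwarg check assumed for illustration:

class CachePositionTracker:
    # Toy stand-in for the position bookkeeping in QuantAttentionFused.
    def __init__(self):
        self.start_pos = 0

    def step(self, seqlen, **kwargs):
        # First step of generate(): transformers passes past_key_value=None,
        # so the fused KV cache must restart from position 0.
        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
            self.start_pos = 0
        pos = self.start_pos
        self.start_pos += seqlen  # advance for the next decoding step
        return pos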
......
@@ -3,15 +3,17 @@ import awq_inference_engine
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-class QuantLlamaMLP(nn.Module):
+class QuantFusedMLP(nn.Module):
def __init__(
self,
gate_proj,
down_proj,
-up_proj
+up_proj,
+activation = F.silu,
):
super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight)
self.register_buffer('gate_proj_scales', gate_proj.scales)
self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
......
@@ -32,6 +34,8 @@ class QuantLlamaMLP(nn.Module):
self.linear = awq_inference_engine.gemm_forward_cuda
self.group_size = 8
+self.activation = activation
def forward(self, x):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
......
@@ -49,8 +53,22 @@ class QuantLlamaMLP(nn.Module):
self.up_proj_qzeros,
self.group_size,
)
-x = F.silu(gate_output) * up_output
+x = self.activation(gate_output) * up_output
x = x.reshape(out_shape)
x = self.down_proj(x)
-return x
\ No newline at end of file
+return x
+class QuantLlamaMLP(QuantFusedMLP):
+    r"""
+    QuantLlamaMLP class kept for backward compatibility; in the future, users
+    should always use the `QuantFusedMLP` class instead.
+    """
+    def __init__(
+        self,
+        gate_proj,
+        down_proj,
+        up_proj
+    ):
+        super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
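A hedged usage sketch of the new interface (not part of the diff): QuantFusedMLP now takes an optional activation that defaults to F.silu, so a model with a different gating activation could pass, for example, F.gelu, while older code importing QuantLlamaMLP keeps working through the backward-compatible subclass. The gate_q, down_q, and up_q names are placeholders for WQLinear modules taken from an already quantized model:

import torch.nn.functional as F
from awq.modules.fused.mlp import QuantFusedMLP, QuantLlamaMLP

# Preferred going forward: explicit activation on the fused, model-agnostic MLP.
fused_mlp = QuantFusedMLP(gate_q, down_q, up_q, activation=F.gelu)

# Existing call sites keep working: QuantLlamaMLP is now a thin subclass
# that forwards to QuantFusedMLP with the default SiLU activation.
legacy_mlp = QuantLlamaMLP(gate_q, down_q, up_q)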