"git@developer.sourcefind.cn:OpenDAS/torchaudio.git" did not exist on "93cc6da7771baf4c7beae0b6373efbe9dc16485d"
Unverified Commit 3b362c0d authored by Younes Belkada, committed by GitHub

[`core`] Replace `QuantLlamaMLP` with `QuantFusedMLP` (#188)

parent 09c73fb2
awq/models/aquila.py
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldAquilaDecoderLayer,
     LlamaForCausalLM as OldAquilaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class AquilaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class AquilaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
...
awq/models/llama.py
@@ -8,7 +8,7 @@ from transformers.models.llama.modeling_llama import (
     LlamaDecoderLayer as OldLlamaDecoderLayer,
     LlamaForCausalLM as OldLlamaForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class LlamaFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
...
awq/models/mistral.py
@@ -8,7 +8,7 @@ from transformers.models.mistral.modeling_mistral import (
     MistralDecoderLayer as OldMistralDecoderLayer,
     MistralForCausalLM as OldMistralForCausalLM
 )
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class MistralAWQForCausalLM(BaseAWQForCausalLM):
@@ -95,7 +95,7 @@ class MistralFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
...
awq/models/yi.py
@@ -4,7 +4,7 @@ from .base import BaseAWQForCausalLM
 from awq.utils.fused_utils import fuse_qkv
 from awq.modules.fused.block import LlamaLikeBlock
 from awq.modules.fused.model import LlamaLikeModel
-from awq.modules.fused.mlp import QuantLlamaMLP
+from awq.modules.fused.mlp import QuantFusedMLP
 from awq.modules.fused.norm import FasterTransformerRMSNorm
 
 class YiAWQForCausalLM(BaseAWQForCausalLM):
@@ -90,7 +90,7 @@ class YiFuser:
                 module.self_attn.k_proj,
                 module.self_attn.v_proj
             )
-            mlp = QuantLlamaMLP(
+            mlp = QuantFusedMLP(
                 module.mlp.gate_proj,
                 module.mlp.down_proj,
                 module.mlp.up_proj
...
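All four model fusers above (Aquila, Llama, Mistral, Yi) receive the same two-line change: import QuantFusedMLP instead of QuantLlamaMLP and construct it from the unchanged gate/down/up projections. For third-party code that has to run against AutoAWQ versions from before and after this rename, a guarded import is one way to stay compatible; this is an illustrative sketch, not part of the commit:

try:
    # Name introduced by this commit
    from awq.modules.fused.mlp import QuantFusedMLP
except ImportError:
    # Older releases only expose the previous name; alias it so the rest of
    # the fuser code can refer to QuantFusedMLP uniformly.
    from awq.modules.fused.mlp import QuantLlamaMLP as QuantFusedMLP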
awq/modules/fused/attn.py
@@ -131,7 +131,16 @@ class QuantAttentionFused(nn.Module):
             elif bsz < self.cache_batch_size:
                 self.cache.decrease_batch_size(bsz)
                 self.cache_batch_size = bsz
+
+            # Always reset to 0
+            self.start_pos = 0
+
+        # In case we re-generate, we need to refresh the starting position
+        # to 0. We detect it by checking if `past_key_values` is set to None,
+        # which indicates that we are on the first step of `generate()`.
+        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
+            self.start_pos = 0
 
         xqkv = self.qkv_proj(hidden_states)
         xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
...
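The attention change is behavioural rather than a rename: the fused attention keeps a rolling cache position (start_pos) for its pre-allocated KV cache, and the new check resets it whenever past_key_value arrives as None, which is what transformers passes on the first forward pass of a fresh generate() call. A stripped-down sketch of that pattern in isolation (the class and variable names below are placeholders, not AutoAWQ's API):

class RollingCacheAttention:
    """Toy illustration of the start_pos reset added above."""

    def __init__(self):
        self.start_pos = 0  # next free slot in a pre-allocated KV cache

    def forward(self, hidden_states, **kwargs):
        # First decoding step of a new generate() call: past_key_value is None,
        # so the cache should be written from position 0 again instead of
        # appending after wherever the previous generation stopped.
        if "past_key_value" in kwargs and kwargs["past_key_value"] is None:
            self.start_pos = 0

        seqlen = hidden_states.shape[1]
        # ... compute attention using cache slots [start_pos, start_pos + seqlen) ...
        self.start_pos += seqlen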
awq/modules/fused/mlp.py
@@ -3,15 +3,17 @@ import awq_inference_engine
 import torch.nn.functional as F
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 
-class QuantLlamaMLP(nn.Module):
+
+class QuantFusedMLP(nn.Module):
     def __init__(
         self,
         gate_proj,
         down_proj,
-        up_proj
+        up_proj,
+        activation = F.silu,
     ):
         super().__init__()
 
         self.register_buffer('gate_proj_qweight', gate_proj.qweight)
         self.register_buffer('gate_proj_scales', gate_proj.scales)
         self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
@@ -32,6 +34,8 @@ class QuantLlamaMLP(nn.Module):
         self.linear = awq_inference_engine.gemm_forward_cuda
         self.group_size = 8
 
+        self.activation = activation
+
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
@@ -49,8 +53,22 @@ class QuantLlamaMLP(nn.Module):
             self.up_proj_qzeros,
             self.group_size,
         )
-        x = F.silu(gate_output) * up_output
+        x = self.activation(gate_output) * up_output
         x = x.reshape(out_shape)
         x = self.down_proj(x)
-        return x
\ No newline at end of file
+        return x
+
+
+class QuantLlamaMLP(QuantFusedMLP):
+    r"""
+    QuantLlamaMLP class kept for backward compatibility; in the future, users
+    should always use the `QuantFusedMLP` class instead.
+    """
+    def __init__(
+        self,
+        gate_proj,
+        down_proj,
+        up_proj
+    ):
+        super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
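For downstream users, the net effect of the mlp.py change is that the fused MLP can now be built with an arbitrary gating activation while the old name keeps working. A minimal usage sketch, assuming layer is a quantized decoder layer whose mlp attribute holds WQLinear projections (the variable name is only for illustration):

import torch.nn.functional as F
from awq.modules.fused.mlp import QuantFusedMLP, QuantLlamaMLP

# New class, with the gate activation made configurable (defaults to F.silu).
fused_mlp = QuantFusedMLP(
    layer.mlp.gate_proj,
    layer.mlp.down_proj,
    layer.mlp.up_proj,
    activation=F.gelu,  # assumption: any elementwise activation works here
)

# The old name still resolves to the SiLU-only behaviour through the
# backward-compatibility subclass added at the end of the diff.
legacy_mlp = QuantLlamaMLP(
    layer.mlp.gate_proj,
    layer.mlp.down_proj,
    layer.mlp.up_proj,
)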