Unverified Commit 783afe50 authored by Casper, committed by GitHub

Merge pull request #18 from casper-hansen/refactor_fused

Refactor fused modules
parents 560fbe59 0aa4a596
from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
@@ -7,10 +6,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
-    def fuse_layers(awq_model):
-        make_quant_attn(awq_model, awq_model.device)
-        make_quant_norm(awq_model)
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: LlamaForCausalLM):
+        fuser = LlamaFuser(model)
+        fuser.fuse_attention()
+        fuser.fuse_rmsnorm()
+        fuser.fuse_mlp()

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
@@ -64,3 +64,77 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
        ))
        return layers
+import torch
+from typing import List, Tuple
+from awq.quantize.qmodule import WQLinear
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantLlamaMLP
+from awq.modules.fused_norm import FTLlamaRMSNorm
+from awq.modules.fused_attn import QuantLlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
+
+class LlamaFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaAttention)
+        ]
+
+        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaRMSNorm)
+        ]
+
+        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaMLP)
+        ]
+
+    def fuse_attention(self):
+        for name, module in self.attention_modules:
+            qkv_layer: WQLinear = self._fuse_qkv(module)
+            attn = QuantLlamaAttention(
+                module.hidden_size,
+                module.num_heads,
+                qkv_layer,
+                module.o_proj,
+                qkv_layer.qweight.device,
+                self.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
+
+    def _fuse_qkv(self, module: LlamaAttention):
+        # get qkv and bias
+        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        # create module
+        qkv_layer = WQLinear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.qweight.device
+        )
+
+        # replace buffers with real weights
+        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        qkv_layer.bias = bias
+
+        return qkv_layer
+
+    def fuse_rmsnorm(self):
+        for name, module in self.rmsnorm_modules:
+            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
+            set_module_name(self.model, name, norm)
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
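
A note on the QKV fusion above: `_fuse_qkv` concatenates the packed `qweight`/`qzeros`/`scales` tensors of the three projections so a single WQLinear produces Q, K and V in one matmul (the packed tensors keep output features on the last axis, hence `dim=1`, while the 1-D bias concatenates on `dim=0`). A minimal sketch with plain `nn.Linear` layers (toy sizes and names, not the real packed WQLinear layout) illustrates why the fused projection is equivalent to the three separate ones:

import torch
import torch.nn as nn

# Toy illustration only: a single linear whose weight is the concatenation of the
# q/k/v weights produces the same outputs as three separate projections.
# nn.Linear stores weights as (out_features, in_features), so we stack along dim=0;
# the packed AWQ tensors above keep output features last, hence dim=1 there.
hidden = 16
q, k, v = (nn.Linear(hidden, hidden, bias=False) for _ in range(3))

fused = nn.Linear(hidden, 3 * hidden, bias=False)
with torch.no_grad():
    fused.weight.copy_(torch.cat([q.weight, k.weight, v.weight], dim=0))

x = torch.randn(2, 4, hidden)
q_out, k_out, v_out = fused(x).chunk(3, dim=-1)  # split the fused output back into Q, K, V
assert torch.allclose(q_out, q(x), atol=1e-5)
assert torch.allclose(v_out, v(x), atol=1e-5)
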
from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
+from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP

class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
    max_new_tokens_key = "max_seq_len"

    @staticmethod
-    def fuse_layers(awq_model):
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: MptForCausalLM):
+        fuser = MptFuser(model)
+        fuser.fuse_mlp()

    @staticmethod
-    def get_model_layers(model):
+    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks

    @staticmethod
-    def get_act_for_scaling(module):
+    def get_act_for_scaling(module: MptBlock):
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
@@ -23,12 +24,12 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
        )

    @staticmethod
-    def move_embed(model, device):
+    def move_embed(model: MptForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)

    @staticmethod
-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
        layers = []

        # attention input
@@ -63,3 +64,27 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
        ))
        return layers
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantMPTMLP
+
+class MptFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mlp_modules: List[Tuple[str, MptMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MptMLP)
+        ]
+
+    def fuse_attention(self):
+        pass
+
+    def fuse_layernorm(self):
+        pass
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file

-import math
import torch
import torch.nn as nn
-from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb
-from awq.quantize.qmodule import WQLinear
import awq_inference_engine
+from torch.nn import functional as F

class QuantLlamaRotaryEmbedding(nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
@@ -64,7 +59,8 @@ class QuantLlamaAttention(nn.Module):
        num_heads,
        qkv_proj,
        o_proj,
-        dev
+        dev,
+        max_new_tokens
    ):
        super().__init__()
        self.hidden_size = hidden_size
@@ -76,7 +72,7 @@ class QuantLlamaAttention(nn.Module):
                f" and `num_heads`: {num_heads}).")
        self.qkv_proj = qkv_proj
        self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device = dev)
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device = dev)

    def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
        """Input shape: Batch x Time x Channel"""
@@ -101,7 +97,7 @@ class QuantLlamaAttention(nn.Module):
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to("cuda:0")
+        value_states = value_states.to(key_states.device)

        if past_key_value is not None:
            # reuse k, v, self_attention
@@ -125,43 +121,3 @@ class QuantLlamaAttention(nn.Module):
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value

-def make_quant_attn(model, dev):
-    """
-    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaAttention):
-            continue
-
-        q_proj = m.q_proj
-        k_proj = m.k_proj
-        v_proj = m.v_proj
-
-        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        g_idx = None
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device)
-        qkv_layer.qweight = qweights
-        qkv_layer.qzeros = qzeros
-        qkv_layer.scales = scales
-        qkv_layer.bias = bias
-
-        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, attn)
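
On the `max_new_tokens` change above: the rotary embedding is assumed to precompute its position cache up to `max_position_embeddings`, so sizing it from the model config rather than a hard-coded 2048 keeps the cache valid for longer generations. A rough plain-PyTorch reference of the standard LLaMA-style cos/sin cache (the helper name and sizes are illustrative; the fused kernel's packed layout may differ):

import torch

def build_rotary_cache(dim: int, max_positions: int, base: float = 10000.0):
    # Standard LLaMA rotary frequencies: one inverse frequency per channel pair,
    # evaluated at every position up to max_positions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_positions).float()
    freqs = torch.outer(t, inv_freq)         # (max_positions, dim // 2)
    emb = torch.cat([freqs, freqs], dim=-1)  # (max_positions, dim)
    return emb.cos(), emb.sin()

# Illustrative sizes: head_dim=128, cache covering 4096 positions.
cos, sin = build_rotary_cache(dim=128, max_positions=4096)
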
-import numpy as np
import torch
import torch.nn as nn
-import torch.nn.functional as F
-from torch.cuda.amp import custom_bwd, custom_fwd
-from transformers.models.llama.modeling_llama import LlamaMLP
import awq_inference_engine
+import torch.nn.functional as F

class QuantMPTMLP(nn.Module):
    def __init__(
@@ -67,25 +63,3 @@ class QuantLlamaMLP(nn.Module):
        c = gate_output * up_output
        c = c.reshape(out_shape)
        return c

-def make_fused_mlp(m, parent_name=''):
-    if not hasattr(make_fused_mlp, "called"):
-        make_fused_mlp.called = True
-    """
-    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
-    """
-    if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
-    elif "mptmlp" in str(m.__class__).lower():
-        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
-
-    for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
-
-        if isinstance(child, QuantLlamaMLP):
-            setattr(m, name, child)
-        elif isinstance(child, QuantMPTMLP):
-            setattr(m, name, child)
-    return m
\ No newline at end of file
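
For context on what `QuantLlamaMLP` fuses: a LLaMA MLP block computes `down_proj(silu(gate_proj(x)) * up_proj(x))`, and the fused module performs that same computation with the quantized projections. A plain-PyTorch reference of the computation (toy class name and sizes, unquantized Linear layers, for illustration only):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ReferenceLlamaMLP(nn.Module):
    # Unquantized reference of the LLaMA MLP that the fused module replaces.
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)

    def forward(self, x):
        # gate path through SiLU, elementwise product with the up path, then down-projection
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))

mlp = ReferenceLlamaMLP(hidden_size=32, intermediate_size=64)
out = mlp(torch.randn(2, 5, 32))
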
import torch
from torch import nn
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
import awq_inference_engine

class FTLlamaRMSNorm(nn.Module):
@@ -16,26 +15,3 @@ class FTLlamaRMSNorm(nn.Module):
        output = torch.empty_like(x)
        awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
        return output

-def make_quant_norm(model):
-    """
-    Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaRMSNorm):
-            continue
-
-        norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, norm)
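
For reference, `FTLlamaRMSNorm` calls a fused CUDA kernel, but the computation it replaces is the standard LLaMA RMSNorm. An equivalent plain-PyTorch version (hypothetical helper name, eps assumed at the usual 1e-6 default, ignoring the fp16/fp32 casting details of the real implementations):

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square over the last dimension, then scale by the learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return weight * x * torch.rsqrt(variance + eps)

x = torch.randn(2, 5, 32)
weight = torch.ones(32)
out = rmsnorm_reference(x, weight)
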
@@ -41,3 +41,15 @@ def simple_dispatch_model(model, device_map):
    model.hf_device_map = device_map
    return model
+
+def set_module_name(model, name, value):
+    if '.' in name:
+        parent_name = name.rsplit('.', 1)[0]
+        child_name = name[len(parent_name) + 1:]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ''
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
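
The `set_module_name` helper added above is what every fuser uses to swap a submodule by its dotted `named_modules()` name. A small self-contained example (toy `Inner`/`Outer` modules, purely illustrative) of how it rebinds a nested child on its parent:

import torch.nn as nn
from awq.utils.utils import set_module_name  # the helper added in the diff above

class Inner(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

class Outer(nn.Module):
    def __init__(self):
        super().__init__()
        self.inner = Inner()

model = Outer()
# "inner.proj" is the dotted name reported by model.named_modules();
# set_module_name resolves the parent ("inner") and rebinds the child attribute.
set_module_name(model, "inner.proj", nn.Identity())
assert isinstance(model.inner.proj, nn.Identity)
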