Commit 63346c34 authored by Casper Hansen

Integrate fused modules into AWQ model loading

parent 870a9dc9
@@ -33,9 +33,9 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
-                             device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
+                             device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
         )
\ No newline at end of file
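With this change, callers opt into the fused modules straight from the loader, and fuse_layers defaults to True at the Auto level. A minimal usage sketch follows; the checkpoint directory and weights filename are placeholders, and the import path is assumed from the package layout:

    from awq.models.auto import AutoAWQForCausalLM

    # Loads the quantized checkpoint and, because fuse_layers=True, runs the
    # model-specific fuse_layers hook right after the weights are dispatched.
    model = AutoAWQForCausalLM.from_quantized(
        "path/to/quantized-model",     # placeholder: quantized checkpoint directory
        "awq_model_w4_g128.pt",        # placeholder: quantized weights filename
        fuse_layers=True,
    )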
@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                             safetensors=False, is_quantized=True):
+                             safetensors=False, is_quantized=True, fuse_layers=False):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -298,6 +298,9 @@ class BaseAWQForCausalLM(nn.Module):
         if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+
+            if fuse_layers:
+                self.fuse_layers(model)
         else:
             # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
...
 from .base import BaseAWQForCausalLM
+from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "LlamaDecoderLayer"
     max_new_tokens_key = "max_position_embeddings"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_quant_attn(awq_model, awq_model.device)
+        make_quant_norm(awq_model)
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
         return model.model.layers
...
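For reference, the same fusion sequence can be run by hand on an already-loaded Llama AWQ model, mirroring what the new fuse_layers hook does after load_checkpoint_and_dispatch. This is only a sketch: awq_model is a placeholder, and the .device attribute is assumed to exist as the hook above implies.

    from awq.modules import make_quant_attn, make_quant_norm, make_fused_mlp

    # awq_model is a placeholder for a loaded, quantized Llama model
    make_quant_attn(awq_model, awq_model.device)  # fuse the attention projections into quantized kernels
    make_quant_norm(awq_model)                    # swap the norm layers for the fused implementation
    make_fused_mlp(awq_model)                     # swap LlamaMLP for QuantLlamaMLP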
 from .base import BaseAWQForCausalLM
+from awq.modules import make_fused_mlp

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

+    @staticmethod
+    def fuse_layers(awq_model):
+        make_fused_mlp(awq_model)
+
     @staticmethod
     def get_model_layers(model):
         return model.transformer.blocks
...
@@ -7,6 +7,27 @@ from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine

+class QuantMPTMLP(nn.Module):
+    def __init__(
+        self,
+        up_proj,
+        act,
+        down_proj
+    ):
+        super().__init__()
+        self.register_buffer('up_proj_qweight', up_proj.qweight)
+        self.register_buffer('up_proj_scales', up_proj.scales)
+        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+        self.up_proj = up_proj
+        self.act = act
+        self.down_proj = down_proj
+
+    def forward(self, x: torch.Tensor):
+        x = x.reshape(-1, x.shape[-1])
+        x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
+
+        return self.down_proj(self.act(x))
+
 class QuantLlamaMLP(nn.Module):
@@ -57,10 +78,15 @@ def make_fused_mlp(m, parent_name=''):
     """
     if isinstance(m, LlamaMLP):
         return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+    elif "mptmlp" in str(m.__class__).lower():
+        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)

     for name, child in m.named_children():
         child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")

         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
-    return m
+        elif isinstance(child, QuantMPTMLP):
+            setattr(m, name, child)
+    return m
\ No newline at end of file
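make_fused_mlp walks the module tree and swaps fused MLP implementations in place, so after this change it covers MPT as well as Llama. A minimal sketch, where model is a placeholder for an already-loaded quantized model:

    from awq.modules import make_fused_mlp

    # Recursively replaces LlamaMLP children with QuantLlamaMLP and MPT MLP
    # children with QuantMPTMLP, returning the (possibly replaced) module.
    model = make_fused_mlp(model)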
@@ -116,13 +116,6 @@ if __name__ == '__main__':
     else:
         stream_generator = StreamGenerator

-    # Optimize AWQ quantized model
-    if args.precision == "W4A16" and isinstance(model, LlamaAWQForCausalLM):
-        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-        make_quant_attn(model.model, args.device)
-        make_quant_norm(model.model)
-        make_fused_mlp(model.model)
-
     model_prompter = get_prompter(model, args.model_path)
     stop_token_ids = get_stop_token_ids(model, args.model_path)
     count = 0
...