"src/vscode:/vscode.git/clone" did not exist on "ef3844d3a83583f36d0166be6753d062b3cbd7dc"
Commit 63346c34 authored by Casper Hansen

Integrate fused modules into AWQ model loading

parent 870a9dc9
@@ -33,9 +33,9 @@ class AutoAWQForCausalLM:
     @classmethod
     def from_quantized(self, quant_path, quant_filename, max_new_tokens=None,
-                       device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
+                       device='balanced', trust_remote_code=True, fuse_layers=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, max_new_tokens, device, trust_remote_code=trust_remote_code, fuse_layers=fuse_layers
         )
\ No newline at end of file
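
For context: the auto class simply forwards the new fuse_layers flag to the model-specific loader it resolves from AWQ_CAUSAL_LM_MODEL_MAP. A minimal usage sketch, assuming AutoAWQForCausalLM is importable from the awq package; the path and checkpoint filename below are purely illustrative:

    from awq import AutoAWQForCausalLM

    # Hypothetical quant_path and quant_filename; fuse_layers=True is the new default here.
    model = AutoAWQForCausalLM.from_quantized(
        "path/to/quantized-model",
        "awq_model_w4_g128.pt",
        fuse_layers=True
    )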
@@ -250,7 +250,7 @@ class BaseAWQForCausalLM(nn.Module):
     @classmethod
     def from_quantized(self, model_path, model_type, model_filename, max_new_tokens=None,
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
-                       safetensors=False, is_quantized=True):
+                       safetensors=False, is_quantized=True, fuse_layers=False):
         # [STEP 1] Download model if path is not a directory
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*"]
@@ -298,6 +298,9 @@ class BaseAWQForCausalLM(nn.Module):
         if is_quantized:
             model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
+
+            if fuse_layers:
+                self.fuse_layers(model)
         else:
             # If not quantized, must load with AutoModelForCausalLM
             device_map = infer_auto_device_map(
......
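
Note on the dispatch: from_quantized is a classmethod whose first parameter is named self, so self here is the concrete model class and self.fuse_layers(model) resolves to that class's staticmethod. A minimal sketch of the pattern with simplified, illustrative names (not the real loader):

    class BaseLoader:
        @staticmethod
        def fuse_layers(model):
            pass  # default: no fusion; architecture-specific subclasses override this

        @classmethod
        def from_quantized(cls, model, fuse_layers=False):
            if fuse_layers:
                cls.fuse_layers(model)  # dispatches to the subclass's override
            return model

    class LlamaLoader(BaseLoader):
        @staticmethod
        def fuse_layers(model):
            # stand-in for make_quant_attn / make_quant_norm / make_fused_mlp
            print("fusing llama modules")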
from .base import BaseAWQForCausalLM
from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

class LlamaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "LlamaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(awq_model):
        make_quant_attn(awq_model, awq_model.device)
        make_quant_norm(awq_model)
        make_fused_mlp(awq_model)

    @staticmethod
    def get_model_layers(model: LlamaForCausalLM):
        return model.model.layers
......
from .base import BaseAWQForCausalLM
from awq.modules import make_fused_mlp

class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
    max_new_tokens_key = "max_seq_len"

    @staticmethod
    def fuse_layers(awq_model):
        make_fused_mlp(awq_model)

    @staticmethod
    def get_model_layers(model):
        return model.transformer.blocks
......
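
Each model family opts into whichever fusions apply: the Llama class fuses attention, RMSNorm, and MLP, while the MPT class fuses only the MLP. A hypothetical sketch of adding fusion support for another architecture, following the same pattern as the two files above (class name, layer_type, and attribute paths are illustrative assumptions):

    from .base import BaseAWQForCausalLM
    from awq.modules import make_fused_mlp

    class SomeOtherAWQForCausalLM(BaseAWQForCausalLM):
        layer_type = "SomeDecoderLayer"                # assumption: the decoder block class name to keep unsplit
        max_new_tokens_key = "max_position_embeddings"

        @staticmethod
        def fuse_layers(awq_model):
            make_fused_mlp(awq_model)                  # reuse only the fusions that match this architecture

        @staticmethod
        def get_model_layers(model):
            return model.transformer.blocks            # assumption: wherever this model stores its blocks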
@@ -7,6 +7,27 @@ from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine
 
+class QuantMPTMLP(nn.Module):
+    def __init__(
+        self,
+        up_proj,
+        act,
+        down_proj
+    ):
+        super().__init__()
+        self.register_buffer('up_proj_qweight', up_proj.qweight)
+        self.register_buffer('up_proj_scales', up_proj.scales)
+        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
+
+        self.up_proj = up_proj
+        self.act = act
+        self.down_proj = down_proj
+
+    def forward(self, x: torch.Tensor):
+        x = x.reshape(-1, x.shape[-1])
+        x = awq_inference_engine.gemm_forward_cuda(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, 8)
+        return self.down_proj(self.act(x))
+
 class QuantLlamaMLP(nn.Module):
@@ -57,10 +78,15 @@ def make_fused_mlp(m, parent_name=''):
     """
     if isinstance(m, LlamaMLP):
         return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
+    elif "mptmlp" in str(m.__class__).lower():
+        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
 
     for name, child in m.named_children():
         child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
 
         if isinstance(child, QuantLlamaMLP):
             setattr(m, name, child)
-    return m
+        elif isinstance(child, QuantMPTMLP):
+            setattr(m, name, child)
+    return m
\ No newline at end of file
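
make_fused_mlp works by matching MLP submodules (LlamaMLP by type, the MPT MLP by class-name string) and swapping them in place via setattr while recursing through named_children. A self-contained sketch of that replacement pattern in plain PyTorch, with nn.GELU standing in for a fused module (no quantized kernels involved):

    import torch.nn as nn

    def swap_modules(module, target_cls, make_replacement):
        """Recursively replace every submodule of type target_cls in place."""
        if isinstance(module, target_cls):
            return make_replacement(module)
        for name, child in module.named_children():
            new_child = swap_modules(child, target_cls, make_replacement)
            if new_child is not child:
                setattr(module, name, new_child)  # re-register the replacement under the same name
        return module

    # Example: replace every nn.ReLU in a toy model with nn.GELU.
    toy = nn.Sequential(nn.Linear(4, 4), nn.ReLU(), nn.Linear(4, 4))
    swap_modules(toy, nn.ReLU, lambda _m: nn.GELU())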
@@ -116,13 +116,6 @@ if __name__ == '__main__':
     else:
         stream_generator = StreamGenerator
 
-    # Optimize AWQ quantized model
-    if args.precision == "W4A16" and isinstance(model, LlamaAWQForCausalLM):
-        from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp
-        make_quant_attn(model.model, args.device)
-        make_quant_norm(model.model)
-        make_fused_mlp(model.model)
-
     model_prompter = get_prompter(model, args.model_path)
     stop_token_ids = get_stop_token_ids(model, args.model_path)
     count = 0
......