Unverified commit 5b9f3c47 authored by Casper, committed by GitHub

Mixtral: Mixture of Experts quantization (#251)

parent 2350a4d0
@@ -10,3 +10,4 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
\ No newline at end of file
@@ -14,6 +14,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"gptj": GPTJAWQForCausalLM,
"gpt_bigcode": GptBigCodeAWQForCausalLM,
"mistral": MistralAWQForCausalLM,
"mixtral": MixtralAWQForCausalLM,
"gpt_neox": GPTNeoXAWQForCausalLM,
"aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM,
......
@@ -12,7 +12,11 @@ from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
from transformers import (
AutoModelForCausalLM,
AutoConfig,
@@ -24,7 +28,6 @@ from accelerate.big_modeling import (
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from accelerate.utils import get_balanced_memory
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config):
@@ -176,7 +179,7 @@ class BaseAWQForCausalLM(nn.Module):
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
if safetensors:
ignore_patterns.extend(["*.pt*", "*.bin*"])
ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
else:
ignore_patterns.append("*.safetensors*")
@@ -215,6 +218,9 @@ class BaseAWQForCausalLM(nn.Module):
# Get every linear layer in a block
named_linears = get_named_linears(layer)
# Filter out the linear layers we don't want to quantize
named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)
# Replace activation functions
self._scale_activations(self, layer)
......
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
# TODO: Fix perplexity on fusing Mixtral
#fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
# linear in
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))
# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
))
return layers
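# Editor's note (not part of this commit): for one Mixtral decoder layer the list
# returned above contains, in order:
#   1. a group scaling q/k/v_proj against input_layernorm,
#   2. a group scaling o_proj against v_proj, only when their weight shapes match
#      (skipped under grouped-query attention),
#   3. a single group scaling w1 and w3 of every expert against
#      post_attention_layernorm, fed by the block_sparse_moe input features,
#   4. one group per expert scaling that expert's w2 against its w3.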
class MixtralFuser:
def __init__(self, model: OldMixtralForCausalLM):
self.model = model
self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
)
# Adapt to mixture of experts
for i in range(len(module.block_sparse_moe.experts)):
mlp = QuantFusedMLP(
gate_proj=module.block_sparse_moe.experts[i].w1,
down_proj=module.block_sparse_moe.experts[i].w2,
up_proj=module.block_sparse_moe.experts[i].w3
)
module.block_sparse_moe.experts[i] = mlp
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
)
blocks.append(MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=module.block_sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
@@ -2,6 +2,40 @@ import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MixtralBlock(nn.Module):
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
moe, norm_1, norm_2, dev, max_seq_len
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
)
h = hidden_states.to(attn_output.device) + attn_output
out, _ = self.moe.forward(self.norm_2(h))
out = h + out
return out, None, past_key_value
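# Editor's note (not part of this commit): the block follows the usual pre-norm
# residual layout: attention and the MoE each consume a normalized copy of the
# hidden states and add their output back onto the unnormalized residual h.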
class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
......
@@ -36,7 +36,7 @@ class QuantFusedMLP(nn.Module):
self.activation = activation
def forward(self, x):
def forward(self, x, routing_weights=None):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = self.linear(
@@ -57,6 +57,9 @@ class QuantFusedMLP(nn.Module):
x = x.reshape(out_shape)
x = self.down_proj(x)
if routing_weights is not None:
x = routing_weights * x
return x
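# Editor's illustrative sketch (not part of this commit): how an MoE caller is
# expected to use the new `routing_weights` argument, scaling one expert's output
# by the router probability of each routed token. Plain nn.Linear layers stand in
# for the quantized projections; names and shapes below are assumed.
import torch
import torch.nn as nn
import torch.nn.functional as F

hidden_dim, intermediate_dim, n_tokens = 16, 32, 4
w1 = nn.Linear(hidden_dim, intermediate_dim, bias=False)  # gate_proj
w3 = nn.Linear(hidden_dim, intermediate_dim, bias=False)  # up_proj
w2 = nn.Linear(intermediate_dim, hidden_dim, bias=False)  # down_proj

x = torch.randn(n_tokens, hidden_dim)        # tokens routed to this expert
expert_out = w2(F.silu(w1(x)) * w3(x))       # gated (SwiGLU-style) expert MLP
routing_weights = torch.rand(n_tokens, 1)    # router probability per token
weighted = routing_weights * expert_out      # same scaling as `routing_weights * x` above
print(weighted.shape)                        # torch.Size([4, 16])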
......
@@ -2,8 +2,63 @@ import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
class MixtralModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.embedding(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm(h)
return MoeModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
router_logits=(),
)
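# Editor's note (not part of this commit): the fused blocks keep their own KV
# cache, so past_key_value is passed as None here and the causal mask is built
# from the first block's attn.start_pos rather than from past_key_values.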
class LlamaLikeModel(nn.Module):
"""
......
......@@ -10,7 +10,13 @@ from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize
)
class AwqQuantizer:
@@ -70,13 +76,6 @@ class AwqQuantizer:
return w
def _exclude_layers_to_not_quantize(self, linear_layers):
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in self.modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers
def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
@@ -91,7 +90,7 @@ class AwqQuantizer:
named_linears = get_named_linears(self.modules[i])
# Filter out the linear layers we don't want to quantize
named_linears = self._exclude_layers_to_not_quantize(named_linears)
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
@@ -387,6 +386,11 @@ class AwqQuantizer:
input_feat = defaultdict(list)
handles = []
# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
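# Editor's note (not part of this commit): hooking block_sparse_moe itself records
# the input of the whole MoE block, which get_layers_for_scaling consumes as
# input_feat['block_sparse_moe'] for the shared w1/w3 scaling group.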
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
......
@@ -33,7 +33,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
if isinstance(prev_op, nn.Linear) and isinstance(layers, list) and isinstance(layers[0], nn.Linear):
scale_fc_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
@@ -101,6 +104,25 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
for p in fc2.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
if not isinstance(fcs, list):
fcs = [fcs]
scales = scales.to(fc1.weight.device)
fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
for p in fc1.parameters():
assert torch.isnan(p).sum() == 0
for fc in fcs:
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
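# Editor's illustrative sketch (not part of this commit): the rescaling done by
# scale_fc_fcs above leaves the composed function unchanged. Toy shapes are
# assumed and the same tensor ops are applied inline instead of importing awq.
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(8, 8, bias=False)   # plays prev_op (e.g. an expert's w3)
fc2 = nn.Linear(8, 4, bias=False)   # plays a following layer (e.g. that expert's w2)
x = torch.randn(2, 8)

with torch.no_grad():
    ref = fc2(fc1(x))
    scales = torch.rand(8) + 0.5                            # per-channel scales found by the AWQ search
    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))   # shrink prev_op output channels
    fc2.weight.mul_(scales.view(1, -1))                     # grow the matching input channels
    print(torch.allclose(ref, fc2(fc1(x)), atol=1e-5))      # True: the composition is preserved
# For Mixtral's w3 -> w2 pair, the elementwise silu(w1(x)) factor in between also
# commutes with a per-channel scale, so the same identity holds there.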
@torch.no_grad()
def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
......
@@ -41,4 +41,14 @@ def append_str_prefix(x, prefix):
elif isinstance(x, list):
return [append_str_prefix(y, prefix) for y in x]
else:
return x
\ No newline at end of file
return x
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
if modules_to_not_convert is None:
return linear_layers
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers
\ No newline at end of file
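# Editor's illustrative sketch (not part of this commit): typical use of the helper
# defined above, dropping Mixtral's router ("gate") linears from the set of layers
# to quantize. The module names below are made up for illustration.
import torch.nn as nn

named_linears = {
    "self_attn.q_proj": nn.Linear(8, 8),
    "block_sparse_moe.gate": nn.Linear(8, 8),
    "block_sparse_moe.experts.0.w1": nn.Linear(8, 8),
}
kept = exclude_layers_to_not_quantize(named_linears, modules_to_not_convert=["gate"])
print(list(kept))  # ['self_attn.q_proj', 'block_sparse_moe.experts.0.w1']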
@@ -7,7 +7,9 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
quant_path = 'mixtral-instruct-awq'
modules_to_not_convert = ["gate"]
quant_config = {
"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM",
"modules_to_not_convert": modules_to_not_convert
}
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, safetensors=True, **{"low_cpu_mem_usage": True}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(
tokenizer,
quant_config=quant_config,
modules_to_not_convert=modules_to_not_convert
)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
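# Editor's follow-up sketch (not part of this commit): loading the saved checkpoint
# back for inference, assuming AutoAWQ's existing from_quantized API. The prompt and
# generation settings are illustrative.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'mixtral-instruct-awq'
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)  # fusing is disabled for Mixtral in this commit
tokenizer = AutoTokenizer.from_pretrained(quant_path)

tokens = tokenizer("What is a mixture of experts?", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))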