Unverified Commit 5b9f3c47 authored by Casper, committed by GitHub

Mixtral: Mixture of Experts quantization (#251)

parent 2350a4d0
...@@ -10,3 +10,4 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
from .aquila import AquilaAWQForCausalLM
from .yi import YiAWQForCausalLM
from .qwen import QwenAWQForCausalLM
from .mixtral import MixtralAWQForCausalLM
\ No newline at end of file
...@@ -14,6 +14,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"gptj": GPTJAWQForCausalLM,
"gpt_bigcode": GptBigCodeAWQForCausalLM,
"mistral": MistralAWQForCausalLM,
"mixtral": MixtralAWQForCausalLM,
"gpt_neox": GPTNeoXAWQForCausalLM, "gpt_neox": GPTNeoXAWQForCausalLM,
"aquila": AquilaAWQForCausalLM, "aquila": AquilaAWQForCausalLM,
"Yi": YiAWQForCausalLM, "Yi": YiAWQForCausalLM,
......
...@@ -12,7 +12,11 @@ from huggingface_hub import snapshot_download
from awq.quantize.quantizer import AwqQuantizer
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from awq.utils.module import (
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
from transformers import (
AutoModelForCausalLM,
AutoConfig,
...@@ -24,7 +28,6 @@ from accelerate.big_modeling import (
infer_auto_device_map,
load_checkpoint_and_dispatch,
)
from accelerate.utils import get_balanced_memory
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config):
...@@ -176,7 +179,7 @@ class BaseAWQForCausalLM(nn.Module):
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
if safetensors:
ignore_patterns.extend(["*.pt*", "*.bin*"])
ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
else:
ignore_patterns.append("*.safetensors*")
...@@ -215,6 +218,9 @@ class BaseAWQForCausalLM(nn.Module):
# Get every linear layer in a block
named_linears = get_named_linears(layer)
# Filter out the linear layers we don't want to exclude
named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)
# Replace activation functions
self._scale_activations(self, layer)
......
import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
# TODO: Fix perplexity on fusing Mixtral
#fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)
@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
layers = []
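# Each entry groups the op that feeds a set of linear layers together with
# those layers and their captured input features, so the quantizer can search
# one shared AWQ scale per group.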
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
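# MoE input: w1/w3 of every expert consume the post-attention layernorm
# output and are scaled as one group. The router (block_sparse_moe.gate) is
# intentionally not listed here and is excluded from quantization via
# modules_to_not_convert=["gate"].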
# linear in
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))
# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
))
return layers
class MixtralFuser:
def __init__(self, model: OldMixtralForCausalLM):
self.model = model
self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
)
# Adapt to mixture of experts
for i in range(len(module.block_sparse_moe.experts)):
mlp = QuantFusedMLP(
gate_proj=module.block_sparse_moe.experts[i].w1,
down_proj=module.block_sparse_moe.experts[i].w2,
up_proj=module.block_sparse_moe.experts[i].w3
)
module.block_sparse_moe.experts[i] = mlp
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
)
blocks.append(MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=module.block_sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
...@@ -2,6 +2,40 @@ import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MixtralBlock(nn.Module):
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
moe, norm_1, norm_2, dev, max_seq_len
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = n_kv_heads
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
):
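# Pre-norm residual block: fused attention plus residual, then the sparse
# MoE feed-forward plus a second residual.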
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
)
h = hidden_states.to(attn_output.device) + attn_output
out, _ = self.moe.forward(self.norm_2(h))
out = h + out
return out, None, past_key_value
class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
......
...@@ -36,7 +36,7 @@ class QuantFusedMLP(nn.Module):
self.activation = activation
def forward(self, x):
def forward(self, x, routing_weights=None):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = self.linear(
...@@ -57,6 +57,9 @@ class QuantFusedMLP(nn.Module):
x = x.reshape(out_shape)
x = self.down_proj(x)
if routing_weights is not None:
x = routing_weights * x
return x
......
...@@ -2,8 +2,63 @@ import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
class MixtralModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.embedding(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
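# Each fused attention module manages its own KV cache, so only the hidden
# states and the causal mask are threaded through the blocks.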
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h = self.norm(h)
return MoeModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
router_logits=(),
)
class LlamaLikeModel(nn.Module):
"""
......
...@@ -10,7 +10,13 @@ from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize
)
class AwqQuantizer:
...@@ -70,13 +76,6 @@ class AwqQuantizer:
return w
def _exclude_layers_to_not_quantize(self, linear_layers):
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in self.modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers
def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
...@@ -91,7 +90,7 @@ class AwqQuantizer:
named_linears = get_named_linears(self.modules[i])
# Filter out the linear layers we don't want to exclude
named_linears = self._exclude_layers_to_not_quantize(named_linears)
named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
...@@ -387,6 +386,11 @@ class AwqQuantizer:
input_feat = defaultdict(list)
handles = []
# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
for name in named_linears:
handles.append(named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name,
......
...@@ -33,7 +33,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
layer.cuda()
scales.cuda()
if isinstance(prev_op, nn.Linear):
if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
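# New path: a Linear prev_op feeding a list of Linear layers is scaled
# jointly through scale_fc_fcs (added below for the Mixtral experts).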
scale_fc_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
...@@ -101,6 +104,25 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
for p in fc2.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
if not isinstance(fcs, list):
fcs = [fcs]
scales = scales.to(fc1.weight.device)
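# AWQ scaling identity: divide the fc1 output channels that feed the
# downstream layers by the scales, and multiply the matching input channels
# of every downstream Linear by the same scales, leaving the product unchanged.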
fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
for p in fc1.parameters():
assert torch.isnan(p).sum() == 0
for fc in fcs:
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
......
...@@ -41,4 +41,14 @@ def append_str_prefix(x, prefix):
elif isinstance(x, list):
return [append_str_prefix(y, prefix) for y in x]
else:
return x
\ No newline at end of file
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
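# Drop any linear layer whose name matches an entry in modules_to_not_convert
# (e.g. the Mixtral router "gate"); return everything unchanged when the list is None.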
if modules_to_not_convert is None:
return linear_layers
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers
\ No newline at end of file
...@@ -7,7 +7,9 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
quant_path = 'mixtral-instruct-awq'
modules_to_not_convert = ["gate"]
quant_config = {
"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM",
"modules_to_not_convert": modules_to_not_convert
}
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, safetensors=True, **{"low_cpu_mem_usage": True}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(
tokenizer,
quant_config=quant_config,
modules_to_not_convert=modules_to_not_convert
)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
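For reference, loading the quantized Mixtral checkpoint back for inference could look like the minimal sketch below. This is an assumption based on the existing AutoAWQForCausalLM.from_quantized API rather than part of this change, and it keeps layer fusing off because fused Mixtral blocks are still disabled above (perplexity TODO):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'mixtral-instruct-awq'

# Load the AWQ-quantized weights; fuse_layers=False since fusing Mixtral is
# still marked as TODO in this change (usage sketch, not part of the diff).
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

tokens = tokenizer("Mixture of Experts models are", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))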