Unverified Commit bcaa8a36 authored by Casper, committed by GitHub

v0.2.0 (#330)


Co-authored-by: jinz2014 <7799920+jinz2014@users.noreply.github.com>
Co-authored-by: Jin Z <5zj@cousteau.ftpn.ornl.gov>
parent c69d3b65
@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OldMistralDecoderLayer,
MistralForCausalLM as OldMistralForCausalLM
MistralForCausalLM as OldMistralForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MistralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MistralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldMistralForCausalLM):
@@ -22,53 +23,65 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
def get_model_layers(model: OldMistralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module: OldMistralDecoderLayer):
return dict(
is_scalable=False
)
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldMistralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMistralDecoderLayer, input_feat, module_kwargs):
def get_layers_for_scaling(
module: OldMistralDecoderLayer, input_feat, module_kwargs
):
layers = []
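# Rough summary of the entries appended below (a hedged reading of the AWQ scale search):
#   prev_op        - the op right before the scaled linears; its output absorbs the inverse scales
#   layers         - linear layers that share the same input and receive the scales
#   inp            - calibration activations captured for those layers (input_feat)
#   module2inspect - parent module that is re-run while searching for the best scales
#   kwargs         - extra kwargs forwarded to module2inspect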
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
@@ -78,10 +91,11 @@ class MistralFuser:
self.model = model
self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MistralDecoderLayer'.lower() in module.__class__.__name__.lower()
(name, module)
for name, module in self.model.named_modules()
if "MistralDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
@@ -92,29 +106,30 @@ class MistralFuser:
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
module.self_attn.v_proj,
)
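# fuse_qkv presumably concatenates the quantized q/k/v projections into a single
# linear, so the fused attention computes xqkv with one GEMM (cf. the
# attention_shapes["xqkv_view"] slicing in awq/modules/fused/attn.py).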
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
......
import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from awq.modules.fused.moe import FusedSparseMoeBlock
from awq.utils.fused_utils import fuse_qkv, fuse_linears
from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
MixtralForCausalLM as OldMixtralForCausalLM,
)
from awq.modules.linear import WQLinear_GEMM
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MixtralDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
max_seq_len_key = "max_position_embeddings"
modules_to_not_convert = ["gate"]
@staticmethod
def fuse_layers(model: OldMixtralForCausalLM):
fuser = MixtralFuser(model)
fuser.fuse_transformer()
@staticmethod
def get_model_layers(model: OldMixtralForCausalLM):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldMixtralForCausalLM, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
def get_layers_for_scaling(
module: OldMixtralDecoderLayer, input_feat, module_kwargs
):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear in
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[
w for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat['block_sparse_moe'],
module2inspect=module.block_sparse_moe,
))
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[
w
for expert in module.block_sparse_moe.experts
for w in [expert.w1, expert.w3]
],
inp=input_feat["block_sparse_moe"],
module2inspect=module.block_sparse_moe,
)
)
# linear out
for i, expert in enumerate(module.block_sparse_moe.experts):
layers.append(dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
))
layers.append(
dict(
prev_op=expert.w3,
layers=[expert.w2],
inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
)
)
return layers
@@ -81,49 +99,89 @@ class MixtralFuser:
self.model = model
self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
(name, module)
for name, module in self.model.named_modules()
if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
module: OldMixtralDecoderLayer
for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
device = next(iter(module.state_dict().values())).device
qkv = fuse_qkv(
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
module.post_attention_layernorm.variance_epsilon,
)
sparse_moe = module.block_sparse_moe
if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1:
fused_w1w3s = [
fuse_linears(
[
sparse_moe.experts[i].w1,
sparse_moe.experts[i].w3,
],
device,
)
for i in range(len(sparse_moe.experts))
]
stacked_w1w3s = fuse_linears(
fused_w1w3s, device, dim=0, operation=torch.stack
)
stacked_w2s = fuse_linears(
[expert.w2 for expert in sparse_moe.experts],
device,
dim=0,
operation=torch.stack,
)
sparse_moe = FusedSparseMoeBlock(
top_k=sparse_moe.top_k,
gate=sparse_moe.gate,
ws=stacked_w1w3s,
w2s=stacked_w2s,
)
blocks.append(
MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)
blocks.append(MixtralBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
moe=module.block_sparse_moe,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens,
rope_theta=self.model.config.rope_theta
))
model_norm = FasterTransformerRMSNorm(
self.model.model.norm.weight,
self.model.model.norm.variance_epsilon,
)
self.model.model = MixtralModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
model_norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM
class MptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "MPTBlock"
max_new_tokens_key = "max_seq_len"
max_seq_len_key = "max_seq_len"
@staticmethod
def fuse_layers(model: MptForCausalLM):
@@ -13,73 +14,84 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
def get_model_layers(model: MptForCausalLM):
return model.transformer.blocks
@staticmethod
def get_act_for_scaling(module: OldMptBlock):
return dict(
is_scalable=True,
scale_name="ffn.act",
scale_layer=module.ffn.act,
scale_shape=module.ffn.up_proj.out_features
scale_shape=module.ffn.up_proj.out_features,
)
@staticmethod
def move_embed(model: MptForCausalLM, device: str):
model.transformer.wte = model.transformer.wte.to(device)
model.transformer.emb_drop = model.transformer.emb_drop.to(device)
@staticmethod
def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
layers = []
if module_kwargs.get("output_attentions") is not None:
module_kwargs.pop("output_attentions")
# attention input
layers.append(dict(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat['attn.Wqkv'],
module2inspect=module.attn,
kwargs=module_kwargs
))
layers.append(
dict(
prev_op=module.norm_1,
layers=[module.attn.Wqkv],
inp=input_feat["attn.Wqkv"],
module2inspect=module.attn,
kwargs=module_kwargs,
)
)
# attention output
layers.append(dict(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat['attn.out_proj']
))
layers.append(
dict(
prev_op=module.attn.Wqkv,
layers=[module.attn.out_proj],
inp=input_feat["attn.out_proj"],
)
)
# linear 1
layers.append(dict(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat['ffn.up_proj'],
module2inspect=module.ffn
))
layers.append(
dict(
prev_op=module.norm_2,
layers=[module.ffn.up_proj],
inp=input_feat["ffn.up_proj"],
module2inspect=module.ffn,
)
)
# linear 2
layers.append(dict(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat['ffn.down_proj']
))
layers.append(
dict(
prev_op=module.ffn.act,
layers=[module.ffn.down_proj],
inp=input_feat["ffn.down_proj"],
)
)
return layers
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel
class MptFuser:
def __init__(self, model: MptForCausalLM):
self.model = model
self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
(name, module) for name, module in self.model.named_modules()
if 'mptblock' in module.__class__.__name__.lower()
(name, module)
for name, module in self.model.named_modules()
if "mptblock" in module.__class__.__name__.lower()
]
def fuse_transformer(self):
@@ -87,17 +99,19 @@ class MptFuser:
module: OldMptBlock
for module in self.model.transformer.blocks:
blocks.append(MPTBlock(
self.model.config.d_model,
self.model.config.n_heads,
module.attn.Wqkv,
module.attn.out_proj,
module.ffn,
module.norm_1,
module.norm_2,
next(iter(module.state_dict().values())).device,
self.model.config.max_new_tokens
))
blocks.append(
MPTBlock(
self.model.config.d_model,
self.model.config.n_heads,
module.attn.Wqkv,
module.attn.out_proj,
module.ffn,
module.norm_1,
module.norm_2,
next(iter(module.state_dict().values())).device,
self.model.config.max_seq_len,
)
)
self.model.transformer = MPTModel(
self.model.config.vocab_size,
@@ -106,4 +120,4 @@ class MptFuser:
self.model.transformer.norm_f,
)
setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
\ No newline at end of file
setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
from .base import BaseAWQForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer
class OptAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "OPTDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def get_model_layers(model: OPTForCausalLM):
return model.model.decoder.layers
@staticmethod
def get_act_for_scaling(module: OPTDecoderLayer):
return dict(
is_scalable=False
)
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OPTForCausalLM, device: str):
model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(
device
)
@staticmethod
def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.self_attn_layer_norm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn,
kwargs=module_kwargs,
))
layers.append(
dict(
prev_op=module.self_attn_layer_norm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat['self_attn.out_proj'],
))
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.out_proj],
inp=input_feat["self_attn.out_proj"],
)
)
# linear 1
layers.append(dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat['fc1'],
))
layers.append(
dict(
prev_op=module.final_layer_norm,
layers=[module.fc1],
inp=input_feat["fc1"],
)
)
# linear 2
layers.append(dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat['fc2'],
))
layers.append(
dict(
prev_op=module.fc1,
layers=[module.fc2],
inp=input_feat["fc2"],
)
)
return layers
\ No newline at end of file
return layers
@@ -3,7 +3,7 @@ from .base import BaseAWQForCausalLM
class QwenAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "QWenBlock"
max_new_tokens_key = "seq_length"
max_seq_len_key = "seq_length"
@staticmethod
def get_model_layers(model):
......
@@ -6,14 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.qwen2.modeling_qwen2 import (
Qwen2DecoderLayer as OldQwen2DecoderLayer,
Qwen2ForCausalLM as OldQwen2ForCausalLM
Qwen2ForCausalLM as OldQwen2ForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm
class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
layer_type = "Qwen2DecoderLayer"
max_new_tokens_key = "max_position_embeddings"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model: OldQwen2ForCausalLM):
@@ -26,9 +26,7 @@ class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
def get_act_for_scaling(module: OldQwen2DecoderLayer):
return dict(
is_scalable=False
)
return dict(is_scalable=False)
@staticmethod
def move_embed(model: OldQwen2ForCausalLM, device: str):
@@ -39,37 +37,49 @@ class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
layers = []
# attention input
layers.append(dict(
prev_op=module.input_layernorm,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
layers.append(
dict(
prev_op=module.input_layernorm,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
layers.append(
dict(
prev_op=module.post_attention_layernorm,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
@@ -79,8 +89,9 @@ class Qwen2Fuser:
self.model = model
self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [
(name, module) for name, module in self.model.named_modules()
if 'Qwen2DecoderLayer'.lower() in module.__class__.__name__.lower()
(name, module)
for name, module in self.model.named_modules()
if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
@@ -93,28 +104,29 @@ class Qwen2Fuser:
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
module.input_layernorm.weight, module.input_layernorm.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.post_attention_layernorm.weight,
module.post_attention_layernorm.variance_epsilon
module.post_attention_layernorm.variance_epsilon,
)
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
)
)
blocks.append(LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
......
@@ -6,9 +6,10 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm
class YiAWQForCausalLM(BaseAWQForCausalLM):
layer_type = "YiDecoderLayer"
max_new_tokens_key = "max_position_embeddings"
max_seq_len_key = "max_position_embeddings"
@staticmethod
def fuse_layers(model):
@@ -18,53 +19,63 @@ class YiAWQForCausalLM(BaseAWQForCausalLM):
@staticmethod
def get_model_layers(model):
return model.model.layers
@staticmethod
def get_act_for_scaling(module):
return dict(
is_scalable=False
)
return dict(is_scalable=False)
@staticmethod
def move_embed(model, device: str):
model.model.embed_tokens = model.model.embed_tokens.to(device)
@staticmethod
def get_layers_for_scaling(module, input_feat, module_kwargs):
layers = []
# attention input
layers.append(dict(
prev_op=module.ln1,
layers=[module.self_attn.q_proj,
module.self_attn.k_proj, module.self_attn.v_proj],
inp=input_feat['self_attn.q_proj'],
module2inspect=module.self_attn, kwargs=module_kwargs,
))
layers.append(
dict(
prev_op=module.ln1,
layers=[
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj,
],
inp=input_feat["self_attn.q_proj"],
module2inspect=module.self_attn,
kwargs=module_kwargs,
)
)
# attention out
# Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
layers.append(dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat['self_attn.o_proj'],
))
layers.append(
dict(
prev_op=module.self_attn.v_proj,
layers=[module.self_attn.o_proj],
inp=input_feat["self_attn.o_proj"],
)
)
# linear 1
layers.append(dict(
prev_op=module.ln2,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat['mlp.gate_proj'],
module2inspect=module.mlp,
))
layers.append(
dict(
prev_op=module.ln2,
layers=[module.mlp.gate_proj, module.mlp.up_proj],
inp=input_feat["mlp.gate_proj"],
module2inspect=module.mlp,
)
)
# linear 2
layers.append(dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat['mlp.down_proj'],
))
layers.append(
dict(
prev_op=module.mlp.up_proj,
layers=[module.mlp.down_proj],
inp=input_feat["mlp.down_proj"],
)
)
return layers
@@ -74,10 +85,11 @@ class YiFuser:
self.model = model
self.yi_blocks: List[Tuple[str, object]] = [
(name, module) for name, module in self.model.named_modules()
if 'YiDecoderLayer'.lower() in module.__class__.__name__.lower()
(name, module)
for name, module in self.model.named_modules()
if "YiDecoderLayer".lower() in module.__class__.__name__.lower()
]
def fuse_transformer(self):
blocks = []
@@ -87,30 +99,30 @@ class YiFuser:
module,
module.self_attn.q_proj,
module.self_attn.k_proj,
module.self_attn.v_proj
module.self_attn.v_proj,
)
norm_1 = FasterTransformerRMSNorm(
module.ln1.weight,
module.ln1.variance_epsilon
module.ln1.weight, module.ln1.variance_epsilon
)
norm_2 = FasterTransformerRMSNorm(
module.ln2.weight,
module.ln2.variance_epsilon
module.ln2.weight, module.ln2.variance_epsilon
)
blocks.append(LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens,
rope_theta=self.model.config.rope_theta
))
blocks.append(
LlamaLikeBlock(
hidden_size=self.model.config.hidden_size,
n_heads=self.model.config.num_attention_heads,
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_seq_len,
rope_theta=self.model.config.rope_theta,
)
)
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
......
import torch.nn as nn
class ScaledActivation(nn.Module):
def __init__(self, module, scales):
super().__init__()
self.act = module
self.scales = nn.Parameter(scales.data)
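# The searched per-channel scales (cf. get_act_for_scaling, e.g. MPT's ffn.act)
# are divided out of the activation output here.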
def forward(self, x):
return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
@@ -9,6 +9,7 @@ from awq.utils.fused_utils import get_attention_shapes
try:
import awq_ft_ext
FT_INSTALLED = True
except:
FT_INSTALLED = False
@@ -16,6 +17,7 @@ except:
HF_NEW_CACHE_FORMAT = False
import transformers
# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
@@ -25,12 +27,12 @@ if HF_NEW_CACHE_FORMAT:
class RoPE(nn.Module):
def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
super(RoPE, self).__init__()
self.freqs_cis = nn.Parameter(
self.precompute_freqs_cis(
hidden_size // n_heads, max_seq_len * 2, rope_theta
).to(device),
requires_grad=False
requires_grad=False,
)
@staticmethod
@@ -58,18 +60,21 @@ class RoPE(nn.Module):
)
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
return xq_out.type_as(xq), xk_out.type_as(xk)
class ALiBi(nn.Module):
def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
super(ALiBi, self).__init__()
# Initialize ALiBi slopes and bias
slopes, bias = self.build_alibi_bias(n_heads, max_seq_len, alibi_bias_max=alibi_bias_max)
slopes, bias = self.build_alibi_bias(
n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
)
self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
@@ -79,27 +84,42 @@ class ALiBi(nn.Module):
m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
m = m.mul(alibi_bias_max / _n_heads)
slopes = 1.0 / torch.pow(2, m)
if _n_heads != n_heads:
slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]
return slopes.view(1, n_heads, 1, 1)
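# Worked example (assuming n_heads is a power of two, so _n_heads == n_heads):
# n_heads = 8, alibi_bias_max = 8 -> m = [1, 2, ..., 8] and
# slopes = 1 / 2**m = [1/2, 1/4, ..., 1/256], i.e. one slope per head.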
@staticmethod
def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(1, 1, 1, seq_len)
alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
1, 1, 1, seq_len
)
slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
alibi_bias = alibi_bias * slopes
slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)
def forward(self, scores, seqlen):
scores += self.bias[..., :seqlen]
return scores
class QuantAttentionFused(nn.Module):
def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj, dev, max_seq_len,
use_alibi=False, attention_shapes=None, rope_theta=10000):
def __init__(
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
dev,
max_seq_len=2048,
use_alibi=False,
attention_shapes=None,
rope_theta=10000,
**kwargs
):
super().__init__()
self.hidden_size = hidden_size
self.n_heads = n_heads
@@ -111,17 +131,29 @@ class QuantAttentionFused(nn.Module):
self.start_pos = 0
self.use_alibi = use_alibi
self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
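# Backwards compatibility: callers may still pass the pre-rename max_new_tokens
# kwarg; it is treated as max_seq_len below.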
if kwargs.get("max_new_tokens") is not None:
max_seq_len = kwargs["max_new_tokens"]
self.max_seq_len = max_seq_len
self.is_hf_transformers = False
self.rope_theta = rope_theta
# attention shapes for self attention
self.attention_shapes = get_attention_shapes(
attention_shapes, max_seq_len, self.cache_batch_size, n_heads, n_kv_heads, self.head_dim
attention_shapes,
max_seq_len,
self.cache_batch_size,
n_heads,
n_kv_heads,
self.head_dim,
)
# cache store that rolls cache
self.cache = WindowedCache(
self.attention_shapes["cache_v"], self.attention_shapes["cache_k"], self.max_seq_len, dev
self.attention_shapes["cache_v"],
self.attention_shapes["cache_k"],
self.max_seq_len,
dev,
)
if use_alibi:
@@ -133,8 +165,10 @@ class QuantAttentionFused(nn.Module):
self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
self.rotary_dim = self.head_dim
self.is_neox = True
def forward(self, hidden_states:torch.Tensor, attention_mask=None, *args, **kwargs):
def forward(
self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
):
bsz, seqlen, _ = hidden_states.shape
# Reallocate cache if batch size changes
@@ -147,18 +181,22 @@ class QuantAttentionFused(nn.Module):
self.cache_batch_size = bsz
# Always reset to 0
self.start_pos = 0
self.start_pos = 0
# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# In case we re-generate, we need to refresh the starting position
# to 0. We detect it by checking if `past_key_values` is set to None,
# which indicates that we are on the first step of `generate()`.
# This is only applicable for `transformers` integration
if self.is_hf_transformers and "past_key_value" in kwargs and kwargs["past_key_value"] is None:
if (
self.is_hf_transformers
and "past_key_value" in kwargs
and kwargs["past_key_value"] is None
):
self.start_pos = 0
xqkv = self.qkv_proj(hidden_states)
xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])
xq = self.attention_shapes["xq_slice"](xqkv)
xk = self.attention_shapes["xk_slice"](xqkv)
xv = self.attention_shapes["xv_slice"](xqkv)
@@ -179,21 +217,22 @@ class QuantAttentionFused(nn.Module):
.permute(0, 2, 3, 1, 4)
.contiguous()
)
self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)
# Only necessary to retrieve from cache when we are not processing context
if seqlen == 1:
xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)
keys = xk
values = xv
if self.n_kv_groups != 0:
keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(values, dim=2, repeats=self.n_kv_groups)
values = torch.repeat_interleave(
values, dim=2, repeats=self.n_kv_groups
)
xq = xq.transpose(1, 2)
keys = keys.transpose(1, 2)
values = values.transpose(1, 2)
@@ -204,7 +243,9 @@ class QuantAttentionFused(nn.Module):
# When seqlen is 1, there is nothing else to attend to
if attention_mask is not None and seqlen > 1:
scores = scores + attention_mask # (bs, n_local_heads, slen, cache_len + slen)
scores = (
scores + attention_mask
) # (bs, n_local_heads, slen, cache_len + slen)
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
output = torch.matmul(scores, values) # (bs, n_local_heads, slen, head_dim)
attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
@@ -215,25 +256,25 @@ class QuantAttentionFused(nn.Module):
alibi_slopes = self.alibi.slopes if self.alibi is not None else None
attention_weight = awq_ft_ext.single_query_attention(
xq, # query
xk, # key
xv, # value
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
self.rope_theta, # rotary embedding base
self.is_neox, # is neox
xq, # query
xk, # key
xv, # value
self.cache.k, # key cache
self.cache.v, # value cache
None, # length per sample
alibi_slopes, # alibi slopes
self.start_pos, # timestep
self.rotary_dim, # rotary embedding dimension
self.rope_theta, # rotary embedding base
self.is_neox, # is neox
)
attention_weight = attention_weight.reshape(bsz, 1, -1)
attn_output = self.o_proj(attention_weight)
self.start_pos += seqlen
# past_key_value is replaced with cache_v, cache_k, returning empty data
# we pass a dummy past kv cache for transformers to be able to retrieve the correct info
# we pass a dummy past kv cache for transformers to be able to retrieve the correct info
# about past key length
past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
......
@@ -2,10 +2,21 @@ import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused
class MixtralBlock(nn.Module):
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
moe, norm_1, norm_2, dev, max_seq_len, rope_theta
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
moe,
norm_1,
norm_2,
dev,
max_seq_len,
rope_theta,
):
super().__init__()
self.n_heads = n_heads
@@ -13,37 +24,62 @@ class MixtralBlock(nn.Module):
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=False,
rope_theta=rope_theta,
).to(dev)
self.norm_2 = norm_2.to(dev)
self.moe = moe
self.device = dev
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
attention_mask=attention_mask,
)
h = hidden_states.to(attn_output.device) + attn_output
out, _ = self.moe.forward(self.norm_2(h))
out = self.moe.forward(self.norm_2(h))
out = h + out
return out, None, past_key_value
class LlamaLikeBlock(nn.Module):
"""
LlamaLikeBlock is intended to be reused across blocks that have
an architecture that closely resembles Llama, e.g. Mistral and Aquila.
"""
def __init__(
self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
mlp, norm_1, norm_2, dev, max_seq_len, rope_theta=10000, use_alibi=False
self,
hidden_size,
n_heads,
n_kv_heads,
qkv_layer,
o_proj,
mlp,
norm_1,
norm_2,
dev,
max_seq_len,
rope_theta=10000,
use_alibi=False,
):
super().__init__()
self.n_heads = n_heads
@@ -51,21 +87,33 @@ class LlamaLikeBlock(nn.Module):
self.hidden_size = hidden_size
self.norm_1 = norm_1.to(dev)
self.attn = QuantAttentionFused(
self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
self.hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=use_alibi,
rope_theta=rope_theta,
).to(dev)
self.norm_2 = norm_2.to(dev)
self.mlp = mlp.to(dev)
self.device = dev
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=norm_out,
past_key_value=past_key_value,
attention_mask=attention_mask
attention_mask=attention_mask,
)
h = hidden_states.to(attn_output.device) + attn_output
@@ -73,23 +121,46 @@ class LlamaLikeBlock(nn.Module):
return out, None, past_key_value
class MPTBlock(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
def __init__(
self,
hidden_size,
n_heads,
qkv_layer,
o_proj,
mpt_mlp,
norm_1,
norm_2,
dev,
max_seq_len,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 0
self.hidden_size = hidden_size
self.norm_1 = norm_1
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=True
hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=True,
).to(dev)
self.norm_2 = norm_2
self.ffn = mpt_mlp.to(dev)
self.device = dev
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
norm_out = self.norm_1(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
@@ -98,16 +169,29 @@ class MPTBlock(nn.Module):
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
use_cache=True,
)
h = hidden_states.to(attn_output.device) + attn_output
out = h + self.ffn.forward(self.norm_2(h))
return out, None, past_key_value
class FalconDecoderLayer(nn.Module):
def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
def __init__(
self,
hidden_size,
n_heads,
qkv_layer,
o_proj,
mlp,
dev,
max_seq_len,
input_layernorm=None,
ln_attn=None,
ln_mlp=None,
new_decoder_arch=True,
):
super().__init__()
self.n_heads = n_heads
self.n_kv_heads = 8 if new_decoder_arch else 0
@@ -117,33 +201,52 @@ class FalconDecoderLayer(nn.Module):
if new_decoder_arch:
attention_shapes = None
else:
attention_shapes = self._get_attention_shapes(n_heads, max_seq_len, self.hidden_size // n_heads)
attention_shapes = self._get_attention_shapes(
n_heads, max_seq_len, self.hidden_size // n_heads
)
# TODO: Falcon has ALiBi implemented but which model uses it?
self.attn = QuantAttentionFused(
hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
dev=dev, max_seq_len=max_seq_len, use_alibi=False,
attention_shapes=attention_shapes
hidden_size,
self.n_heads,
self.n_kv_heads,
qkv_layer,
o_proj,
dev=dev,
max_seq_len=max_seq_len,
use_alibi=False,
attention_shapes=attention_shapes,
).to(dev)
if new_decoder_arch:
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
self.ln_attn = ln_attn # before attention
self.ln_mlp = ln_mlp # before mlp
else:
self.input_layernorm = input_layernorm # before attention
self.input_layernorm = input_layernorm # before attention
self.mlp = mlp
self.device = dev
def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
self.attention_shapes = {
# following fastertransformer definition
"cache_v": (batch_size, 1, max_seq_len, head_dim,),
"cache_v": (
batch_size,
1,
max_seq_len,
head_dim,
),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (batch_size, 1, head_dim // 8, max_seq_len, 8,),
"xqkv_view": (n_heads+2, head_dim),
"cache_k": (
batch_size,
1,
head_dim // 8,
max_seq_len,
8,
),
"xqkv_view": (n_heads + 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, :-2],
"xk_slice": lambda xqkv: xqkv[:, :, [-2]],
"xv_slice": lambda xqkv: xqkv[:, :, [-1]],
@@ -153,27 +256,32 @@ class FalconDecoderLayer(nn.Module):
"xk_reshape": (1, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (1, head_dim),
"single_xv_view": (1, head_dim)
"single_xv_view": (1, head_dim),
}
return self.attention_shapes
def forward(
self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
self,
hidden_states,
past_key_value,
attn_bias=None,
attention_mask=None,
is_causal=None,
):
if self.new_decoder_arch:
layernorm_out = self.ln_attn(hidden_states)
mlp_layernorm_out = self.ln_mlp(hidden_states)
else:
layernorm_out = self.input_layernorm(hidden_states)
attn_output, _, past_key_value = self.attn.forward(
hidden_states=layernorm_out,
past_key_value=past_key_value,
attention_mask=attention_mask,
position_ids=None,
output_attentions=False,
use_cache=True
use_cache=True,
)
h_attn = hidden_states.to(attn_output.device) + attn_output
@@ -182,7 +290,7 @@ class FalconDecoderLayer(nn.Module):
h_mlp = self.mlp.forward(mlp_layernorm_out)
else:
h_mlp = self.mlp.forward(layernorm_out)
out = h_attn + h_mlp
return out, None, past_key_value
\ No newline at end of file
return out, None, past_key_value
import torch
class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
"""
The window size is the same as the max_new_tokens. The window will
automatically roll once max_new_tokens is exceeded.
The window size is the same as the max_seq_len. The window will
automatically roll once max_seq_len is exceeded.
"""
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half()
self.max_seq_len = max_seq_len
def get_kv(self, batch_size, start_pos, seqlen, head_dim):
"""
Gets the key-value store in correct shapes.
"""
xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous()
xv = (
self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
)
xk = (
self.k[:batch_size, :, :, : start_pos + seqlen, :]
.transpose(2, 3)
.contiguous()
)
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
return xv, xk
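# Both returned tensors are presumably [batch_size, start_pos + seqlen, n_kv_heads, head_dim]:
# xv is sliced straight from the value cache, while the packed key cache
# (head_dim // 8, ..., 8) is reshaped back to a contiguous head_dim.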
def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
"""
Updates the values in the key-value store.
@@ -41,19 +48,23 @@ class WindowedCache:
# Zero out the new part
self.v[:, :, -n:, :] = 0
self.k[:, :, :, -n:, :] = 0
return start_pos - n
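# Once max_seq_len is exceeded the window rolls (see the class docstring): the cache
# is presumably shifted back by n positions, the freed tail zeroed above, and the
# caller's start_pos reduced by n so new tokens overwrite the oldest entries.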
def to(self, device):
self.k = self.k.to(device)
self.v = self.v.to(device)
def increase_batch_size(self, to_bsz):
"""Dynamically allocate new kv when batch size changes."""
self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device)
self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device)
self.v = torch.zeros(
to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
)
self.k = torch.zeros(
to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
)
def decrease_batch_size(self, to_bsz):
"""Dynamically remove part of cache if batch size changes."""
self.v = self.v[:to_bsz, :, :, :]
self.k = self.k[:to_bsz, :, :, :, :]
\ No newline at end of file
self.k = self.k[:to_bsz, :, :, :, :]
@@ -5,26 +5,28 @@ from awq.modules.linear.gemv import WQLinear_GEMV
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class QuantFusedMLP(nn.Module):
def __init__(
self,
gate_proj,
down_proj,
up_proj,
activation = F.silu,
activation=F.silu,
):
super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight)
self.register_buffer('gate_proj_scales', gate_proj.scales)
self.register_buffer('gate_proj_qzeros', gate_proj.qzeros)
self.register_buffer('up_proj_qweight', up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros)
self.register_buffer("gate_proj_qweight", gate_proj.qweight)
self.register_buffer("gate_proj_scales", gate_proj.scales)
self.register_buffer("gate_proj_qzeros", gate_proj.qzeros)
self.register_buffer("up_proj_qweight", up_proj.qweight)
self.register_buffer("up_proj_scales", up_proj.scales)
self.register_buffer("up_proj_qzeros", up_proj.qzeros)
self.in_features = gate_proj.in_features
self.intermediate_size = gate_proj.out_features
@@ -66,17 +68,13 @@ class QuantFusedMLP(nn.Module):
x = routing_weights * x
return x
class QuantLlamaMLP(QuantFusedMLP):
r"""
QuantLlamaMLP class kept for backward compatibility; in the future, users
QuantLlamaMLP class kept for backward compatibility; in the future, users
should always use `QuantFusedMLP` class instead.
"""
def __init__(
self,
gate_proj,
down_proj,
up_proj
):
super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
def __init__(self, gate_proj, down_proj, up_proj):
super().__init__(gate_proj, down_proj, up_proj)
@@ -2,8 +2,16 @@ import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
MoeModelOutputWithPast,
)
from awq.modules.fused.block import (
MPTBlock,
FalconDecoderLayer,
LlamaLikeBlock,
MixtralBlock,
)
class MixtralModel(nn.Module):
@@ -47,8 +55,10 @@ class MixtralModel(nn.Module):
h,
mask,
)
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
h, _, past_key_value = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm(h)
return MoeModelOutputWithPast(
@@ -65,6 +75,7 @@ class LlamaLikeModel(nn.Module):
LlamaLikeModel is intended to be reused across models that have
an architecture that closely resembles Llama, e.g. Mistral and Aquila.
"""
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
@@ -72,12 +83,19 @@ class LlamaLikeModel(nn.Module):
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids,
self.last_forward_num_tokens
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
@@ -89,7 +107,7 @@ class LlamaLikeModel(nn.Module):
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h
type_as=h,
)
for layer in self.blocks:
@@ -99,14 +117,17 @@ class LlamaLikeModel(nn.Module):
mask,
)
h, _, past_key_value = layer(
h,
None,
attention_mask=mask,
is_causal=is_causal
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
class MPTModel(nn.Module):
def __init__(self, vocab_size, blocks, wte, norm_f):
@@ -120,10 +141,17 @@ class MPTModel(nn.Module):
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids,
self.last_forward_num_tokens
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
@@ -135,7 +163,7 @@ class MPTModel(nn.Module):
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h
type_as=h,
)
for layer in self.blocks:
@@ -145,14 +173,17 @@ class MPTModel(nn.Module):
mask,
)
h, _, past_key_value = layer(
h,
None,
attention_mask=mask,
is_causal=is_causal
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
class FalconModel(nn.Module):
def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
@@ -166,10 +197,17 @@ class FalconModel(nn.Module):
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids,
self.last_forward_num_tokens
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
@@ -181,7 +219,7 @@ class FalconModel(nn.Module):
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h
type_as=h,
)
for layer in self.blocks:
@@ -191,11 +229,13 @@ class FalconModel(nn.Module):
mask,
)
h, _, past_key_value = layer(
h,
None,
attention_mask=mask,
is_causal=is_causal
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.ln_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
import torch
import triton
from typing import Dict
import triton.language as tl
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class FusedSparseMoeBlock(torch.nn.Module):
def __init__(
self,
top_k,
gate,
ws,
w2s,
):
super().__init__()
self.gate = gate
self.top_k = top_k
self.ws = ws
self.w2s = w2s
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = apply_moe_weights(
self.ws,
self.w2s,
hidden_states,
router_logits,
self.top_k,
renormalize=True,
)
return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
def apply_moe_weights(
w1: Dict[str, torch.Tensor],
w2: Dict[str, torch.Tensor],
x: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024
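# Heuristic assumption: with >= 1024 tokens in the batch, dequantizing the expert
# weights to fp16 and running the Triton fused_moe kernel is expected to beat the
# grouped int4 GEMM path below.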
if FP16_MATMUL_HEURISTIC_CONDITION:
dequant_w1 = awq_ext.dequantize_weights_cuda(
w1.qweight, w1.scales, w1.qzeros, 0, 0, 0, False
).permute(0, 2, 1)
dequant_w2 = awq_ext.dequantize_weights_cuda(
w2.qweight, w2.scales, w2.qzeros, 0, 0, 0, False
).permute(0, 2, 1)
return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, renormalize)
topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
(sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size(
topk_ids, 16, w1.qweight.shape[0]
)
x = x.view(x.shape[0], 1, *x.shape[1:])
gate_up = awq_ext.grouped_gemm_forward(
x,
w1.qweight,
w1.scales,
w1.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
8,
)
out = torch.empty(
(gate_up.shape[:-1] + (gate_up.shape[-1] // 2,)), dtype=x.dtype, device=x.device
)
awq_ext.silu_and_mul(out, gate_up)
out = awq_ext.grouped_gemm_forward(
out,
w2.qweight,
w2.scales,
w2.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
8,
)
return torch.sum(out, dim=1)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
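# Example of the grouped ordering (assumed sizes): with GROUP_SIZE_M = 2,
# num_pid_m = 4 and num_pid_n = 3, programs 0..5 cover pid_m in {0, 1} across all
# three pid_n blocks before programs 6..11 move on to pid_m in {2, 3}, keeping the
# same rows of A resident in L2.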
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int):
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
Example:
Given topk_ids = [[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]], block_size = 4, and num_experts = 4:
- We start with 12 token-expert assignments (4 tokens, each repeated 'top_k' = 3 times) and 4 experts, so each expert needs to process 3 tokens.
- As block_size is 4, we pad 1 token per expert.
- First, flatten topk_ids to [1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2].
- Then append one padding entry per expert, using the sentinel id 12 (= topk_ids.numel()).
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Id 12 denotes a non-existent (padding) token and is ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1),),
dtype=torch.int32,
device=topk_ids.device,
)
expert_ids = torch.empty(
(topk_ids.numel() + num_experts,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
awq_ext.moe_alig_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
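# A minimal pure-Python/PyTorch sketch of the alignment described in the docstring
# above (illustrative only; _moe_align_block_size_reference is a hypothetical
# helper, the awq_ext CUDA op is what actually runs): flatten topk_ids, bucket
# token positions by expert, pad each bucket to a multiple of block_size with the
# sentinel id topk_ids.numel(), and record one expert id per block.
def _moe_align_block_size_reference(topk_ids, block_size, num_experts):
    flat = topk_ids.flatten().tolist()
    pad_id = len(flat)  # sentinel index, masked out by the matmul kernel
    sorted_ids, expert_ids = [], []
    for expert in range(num_experts):
        tokens = [i for i, e in enumerate(flat) if e == expert]
        n_blocks = (len(tokens) + block_size - 1) // block_size
        sorted_ids += tokens + [pad_id] * (n_blocks * block_size - len(tokens))
        expert_ids += [expert] * n_blocks
    return (
        torch.tensor(sorted_ids, dtype=torch.int32),
        torch.tensor(expert_ids, dtype=torch.int32),
        torch.tensor([len(sorted_ids)], dtype=torch.int32),
    )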
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict,
):
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
grid = lambda META: (
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid](
A,
B,
C,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
**config,
)
def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
"""Compute top-k indice and weights from gating logits
Args:
gating_output (torch.Tensor): The output of the gating operation (before softmax).
topk (int): The number of top-k experts to select.
renormalize (bool): If True, renormalize the top-k weights to sum to 1.
"""
M = gating_output.shape[0]
if torch.version.hip is not None:
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=gating_output.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=gating_output.device
)
awq_ext.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
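# Example usage sketch (illustrative only; the shapes and device are assumptions,
# and the non-ROCm path requires a CUDA device with awq_ext installed): for 8
# tokens routed over 4 experts with top-2 routing, `topk_weights` comes back as
# an [8, 2] float32 tensor and `topk_ids` as an [8, 2] integer tensor.
def _fused_topk_example():
    gating_logits = torch.randn(8, 4, device="cuda", dtype=torch.float32)
    topk_weights, topk_ids = fused_topk(gating_logits, topk=2, renormalize=True)
    assert topk_weights.shape == (8, 2) and topk_ids.shape == (8, 2)
    return topk_weights, topk_ids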
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, write the result into hidden_states in place. Defaults to True.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
# assert w1.is_contiguous(), "Expert weights1 must be contiguous"
# assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if topk_ids.numel() <= w1.shape[0]:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
)
awq_ext.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
)
if inplace:
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
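# Example usage sketch (illustrative only; all shapes below are assumptions and a
# CUDA device with awq_ext is required). With w1 holding the fused gate/up
# projection ([E, 2*inter, hidden]) and w2 the down projection
# ([E, hidden, inter]), the output has the same shape as hidden_states.
def _fused_moe_example():
    M, hidden, E, inter = 4, 64, 8, 128  # tokens, hidden size, experts, intermediate size
    x = torch.randn(M, hidden, device="cuda", dtype=torch.float16)
    w1 = torch.randn(E, 2 * inter, hidden, device="cuda", dtype=torch.float16)
    w2 = torch.randn(E, hidden, inter, device="cuda", dtype=torch.float16)
    gating = torch.randn(M, E, device="cuda", dtype=torch.float16)
    out = fused_moe(x, w1, w2, gating, topk=2, renormalize=True, inplace=False)
    assert out.shape == (M, hidden)
    return out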
......@@ -2,4 +2,4 @@ from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin
\ No newline at end of file
from .marlin import WQLinear_Marlin
......@@ -11,6 +11,7 @@ try:
except:
AWQ_INSTALLED = False
# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@staticmethod
......@@ -24,45 +25,29 @@ class WQLinearMMFunction(Function):
w_bit=4,
group_size=128,
bias=None,
out_features=0
out_features=0,
):
# The forward pass can use ctx.
ctx.save_for_backward(x, qweight, qzeros, scales, bias)
ctx.out_features = out_features
out_shape = x.shape[:-1] + (out_features, )
out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16)
if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
0,
0,
0,
False
qweight, scales, qzeros, 0, 0, 0, False
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]),
qweight,
scales,
qzeros,
8
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
)
else:
out = dequantize_gemm(
qweight,
qzeros,
scales,
w_bit,
group_size
)
out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
out = torch.matmul(x, out)
out = out + bias if bias is not None else out
......@@ -71,7 +56,7 @@ class WQLinearMMFunction(Function):
# always want 3D tensor if tensor is 2D
if len(out.shape) == 2:
out = out.unsqueeze(0)
return out
@staticmethod
......@@ -79,13 +64,7 @@ class WQLinearMMFunction(Function):
input, qweight, qzeros, scales, bias = ctx.saved_tensors
weights = awq_ext.dequantize_weights_cuda(
qweight,
scales,
qzeros,
1,
0,
0,
False
qweight, scales, qzeros, 1, 0, 0, False
)
if ctx.needs_input_grad[0]:
......@@ -98,7 +77,9 @@ class WQLinearMMFunction(Function):
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False):
def __init__(
self, w_bit, group_size, in_features, out_features, bias, dev, training=False
):
super().__init__()
if w_bit not in [4]:
......
......@@ -54,7 +54,7 @@ class WQLinear_Marlin(nn.Module):
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs
self.max_par = 8 # partitioning for large inputs
# quick sanity check (make sure alignment holds)
assert self.in_features % self.group_size == 0
......
......@@ -117,7 +117,7 @@ class AwqQuantizer:
best_device = "cuda:" + str(i % torch.cuda.device_count())
else:
best_device = get_best_device()
self.modules[i] = self.modules[i].to(best_device)
common_device = next(self.modules[i].parameters()).device
......@@ -190,15 +190,15 @@ class AwqQuantizer:
linear_layer.weight.data
)
if self.version == "GEMM":
if self.version == "gemm":
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.version == "GEMV":
elif self.version == "gemv":
q_linear_module = WQLinear_GEMV
elif self.version == "Marlin":
elif self.version == "marlin":
q_linear_module = WQLinear_Marlin
else:
......@@ -355,7 +355,9 @@ class AwqQuantizer:
continue
named_linears[name].to(get_best_device())
max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
max_val = self._compute_best_clip(
named_linears[name].weight, input_feat[name]
)
clip_list.append((name, max_val))
named_linears[name].cpu()
......@@ -481,7 +483,9 @@ class AwqQuantizer:
clear_memory()
if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device)
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
best_device
)
return modules, layer_kwargs, inps
......
This diff is collapsed.
......@@ -3,33 +3,41 @@ import logging
from typing import List, Union
from datasets import load_dataset
def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None, n_samples=512, block_size=512,
split="train", text_column="text"):
def get_calib_dataset(
data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None,
n_samples=512,
block_size=512,
split="train",
text_column="text",
):
if isinstance(data, str):
if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
dataset = load_dataset(data, split=split)
dataset = dataset.shuffle(seed=42)
elif isinstance(data, list):
if isinstance(data[0], str):
dataset = [{text_column: text} for text in data]
elif isinstance(data[0][0], int):
dataset = data
dataset = data
else:
raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words.")
" or a list of list of int for tokenized words."
)
else:
raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words.")
" or a list of list of int for tokenized words."
)
samples = []
n_run = 0
for data in dataset:
......@@ -52,4 +60,6 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // block_size
logging.debug(f" * Split into {n_split} blocks")
return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)]
return [
cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
]
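# Example usage sketch (illustrative only; the tokenizer checkpoint is an
# assumption and loading the default pileval split needs network access): given
# the concatenation and slicing above, each returned element is a
# [1, block_size] tensor of token ids used for AWQ calibration.
def _calib_dataset_example():
    from transformers import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
    blocks = get_calib_dataset(
        data="pileval", tokenizer=tokenizer, n_samples=128, block_size=512
    )
    return blocks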