Unverified Commit bcaa8a36 authored by Casper, committed by GitHub

v0.2.0 (#330)


Co-authored-by: jinz2014 <7799920+jinz2014@users.noreply.github.com>
Co-authored-by: Jin Z <5zj@cousteau.ftpn.ornl.gov>
parent c69d3b65
@@ -6,13 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class MistralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MistralDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMistralForCausalLM):
@@ -22,53 +23,65 @@ class MistralAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def get_model_layers(model: OldMistralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldMistralDecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldMistralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldMistralDecoderLayer, input_feat, module_kwargs
    ):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers
@@ -78,10 +91,11 @@ class MistralFuser:
        self.model = model

        self.mistral_blocks: List[Tuple[str, OldMistralDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "MistralDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []
@@ -92,29 +106,30 @@ class MistralFuser:
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
...
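Each entry produced by get_layers_for_scaling describes one group of linears that share an input activation: prev_op is the operation whose output feeds the group, layers are the linears scaled together, inp holds the calibration activations captured for that input, and module2inspect/kwargs let the scale search forward a larger parent module when the group alone does not cover it. A minimal sketch of one such group, assuming a hypothetical decoder layer named layer and a dict of captured activations named input_feat:

# Illustrative sketch only (hypothetical `layer` and `input_feat`); it mirrors the
# structure returned above and is not an additional API.
group = dict(
    prev_op=layer.input_layernorm,       # op producing the shared input
    layers=[                             # linears scaled with one factor per input channel
        layer.self_attn.q_proj,
        layer.self_attn.k_proj,
        layer.self_attn.v_proj,
    ],
    inp=input_feat["self_attn.q_proj"],  # calibration activations for this input
    module2inspect=layer.self_attn,      # parent module to forward during the search
    kwargs={},                           # extra forward kwargs (e.g. attention masks)
)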
import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from awq.modules.fused.moe import FusedSparseMoeBlock
from awq.utils.fused_utils import fuse_qkv, fuse_linears
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldMixtralDecoderLayer,
    MixtralForCausalLM as OldMixtralForCausalLM,
)
from awq.modules.linear import WQLinear_GEMM
from awq.modules.fused.norm import FasterTransformerRMSNorm


class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_seq_len_key = "max_position_embeddings"
    modules_to_not_convert = ["gate"]

    @staticmethod
    def fuse_layers(model: OldMixtralForCausalLM):
        fuser = MixtralFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldMixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(
        module: OldMixtralDecoderLayer, input_feat, module_kwargs
    ):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear in
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[
                    w
                    for expert in module.block_sparse_moe.experts
                    for w in [expert.w1, expert.w3]
                ],
                inp=input_feat["block_sparse_moe"],
                module2inspect=module.block_sparse_moe,
            )
        )

        # linear out
        for i, expert in enumerate(module.block_sparse_moe.experts):
            layers.append(
                dict(
                    prev_op=expert.w3,
                    layers=[expert.w2],
                    inp=input_feat[f"block_sparse_moe.experts.{i}.w2"],
                )
            )

        return layers
@@ -81,49 +99,89 @@ class MixtralFuser:
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "MixtralDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMixtralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            sparse_moe = module.block_sparse_moe
            if isinstance(sparse_moe.experts[0].w1, WQLinear_GEMM) and torch.cuda.device_count() == 1:
                fused_w1w3s = [
                    fuse_linears(
                        [
                            sparse_moe.experts[i].w1,
                            sparse_moe.experts[i].w3,
                        ],
                        device,
                    )
                    for i in range(len(sparse_moe.experts))
                ]

                stacked_w1w3s = fuse_linears(
                    fused_w1w3s, device, dim=0, operation=torch.stack
                )

                stacked_w2s = fuse_linears(
                    [expert.w2 for expert in sparse_moe.experts],
                    device,
                    dim=0,
                    operation=torch.stack,
                )

                sparse_moe = FusedSparseMoeBlock(
                    top_k=sparse_moe.top_k,
                    gate=sparse_moe.gate,
                    ws=stacked_w1w3s,
                    w2s=stacked_w2s,
                )

            blocks.append(
                MixtralBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    moe=sparse_moe,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                )
            )

        model_norm = FasterTransformerRMSNorm(
            self.model.model.norm.weight,
            self.model.model.norm.variance_epsilon,
        )

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            model_norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
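The single-GPU path above first concatenates each expert's w1 and w3 into one linear, then stacks the per-expert results (and the w2 projections) along a new leading dimension so FusedSparseMoeBlock can index every expert from a single tensor. A conceptual sketch of that layout with plain dense weights and toy sizes; the real fuse_linears operates on quantized WQLinear_GEMM modules, so this only illustrates the shapes involved:

# Conceptual illustration of the stacked-expert layout; not the awq fuse_linears code.
import torch
import torch.nn.functional as F

n_experts, hidden, intermediate = 4, 16, 64  # toy sizes
w1 = [torch.randn(intermediate, hidden) for _ in range(n_experts)]  # gate projections
w3 = [torch.randn(intermediate, hidden) for _ in range(n_experts)]  # up projections
w2 = [torch.randn(hidden, intermediate) for _ in range(n_experts)]  # down projections

# concat w1/w3 per expert, then stack experts -> [n_experts, 2 * intermediate, hidden]
ws = torch.stack([torch.cat([a, b], dim=0) for a, b in zip(w1, w3)], dim=0)
# stack the down projections -> [n_experts, hidden, intermediate]
w2s = torch.stack(w2, dim=0)

x = torch.randn(hidden)
e = 2  # expert index chosen by the router's top-k gate
gate, up = (ws[e] @ x).chunk(2, dim=0)  # one matmul yields both halves
out = w2s[e] @ (F.silu(gate) * up)      # SwiGLU expert output, shape [hidden]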
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM


class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
    max_seq_len_key = "max_seq_len"

    @staticmethod
    def fuse_layers(model: MptForCausalLM):
@@ -13,73 +14,84 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def get_model_layers(model: MptForCausalLM):
        return model.transformer.blocks

    @staticmethod
    def get_act_for_scaling(module: OldMptBlock):
        return dict(
            is_scalable=True,
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )

    @staticmethod
    def move_embed(model: MptForCausalLM, device: str):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
        layers = []

        if module_kwargs.get("output_attentions") is not None:
            module_kwargs.pop("output_attentions")

        # attention input
        layers.append(
            dict(
                prev_op=module.norm_1,
                layers=[module.attn.Wqkv],
                inp=input_feat["attn.Wqkv"],
                module2inspect=module.attn,
                kwargs=module_kwargs,
            )
        )

        # attention output
        layers.append(
            dict(
                prev_op=module.attn.Wqkv,
                layers=[module.attn.out_proj],
                inp=input_feat["attn.out_proj"],
            )
        )

        # linear 1
        layers.append(
            dict(
                prev_op=module.norm_2,
                layers=[module.ffn.up_proj],
                inp=input_feat["ffn.up_proj"],
                module2inspect=module.ffn,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.ffn.act,
                layers=[module.ffn.down_proj],
                inp=input_feat["ffn.down_proj"],
            )
        )

        return layers


from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.block import MPTBlock
from awq.modules.fused.model import MPTModel


class MptFuser:
    def __init__(self, model: MptForCausalLM):
        self.model = model

        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "mptblock" in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
@@ -87,17 +99,19 @@ class MptFuser:
        module: OldMptBlock
        for module in self.model.transformer.blocks:
            blocks.append(
                MPTBlock(
                    self.model.config.d_model,
                    self.model.config.n_heads,
                    module.attn.Wqkv,
                    module.attn.out_proj,
                    module.ffn,
                    module.norm_1,
                    module.norm_2,
                    next(iter(module.state_dict().values())).device,
                    self.model.config.max_seq_len,
                )
            )

        self.model.transformer = MPTModel(
            self.model.config.vocab_size,
@@ -106,4 +120,4 @@ class MptFuser:
            self.model.transformer.norm_f,
        )
        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
from .base import BaseAWQForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTDecoderLayer


class OptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "OPTDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def get_model_layers(model: OPTForCausalLM):
        return model.model.decoder.layers

    @staticmethod
    def get_act_for_scaling(module: OPTDecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OPTForCausalLM, device: str):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(
            device
        )

    @staticmethod
    def get_layers_for_scaling(module: OPTDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.self_attn_layer_norm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        layers.append(
            dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.out_proj],
                inp=input_feat["self_attn.out_proj"],
            )
        )

        # linear 1
        layers.append(
            dict(
                prev_op=module.final_layer_norm,
                layers=[module.fc1],
                inp=input_feat["fc1"],
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.fc1,
                layers=[module.fc2],
                inp=input_feat["fc2"],
            )
        )

        return layers
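Across the model definitions, max_seq_len_key names the config field that carries the model's maximum sequence length (max_position_embeddings here and for the Llama-like models, seq_length for Qwen, max_seq_len for MPT); the fused modules size their rolling KV cache from it. A hedged usage sketch, assuming the v0.2.0 loader still exposes AutoAWQForCausalLM.from_quantized and that it accepts fuse_layers and an optional max_seq_len override; the model path is a placeholder:

# Sketch only; the loader arguments are assumptions based on this diff, not a verified listing.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "path/to/awq-quantized-model"  # placeholder path
model = AutoAWQForCausalLM.from_quantized(
    quant_path,
    fuse_layers=True,   # build the fused blocks defined in this PR
    max_seq_len=4096,   # assumed override for the value read via max_seq_len_key
)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

tokens = tokenizer("Hello, my name is", return_tensors="pt").input_ids.cuda()
print(tokenizer.decode(model.generate(tokens, max_new_tokens=32)[0]))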
@@ -3,7 +3,7 @@ from .base import BaseAWQForCausalLM
class QwenAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "QWenBlock"
    max_seq_len_key = "seq_length"

    @staticmethod
    def get_model_layers(model):
...
@@ -6,14 +6,14 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.qwen2.modeling_qwen2 import (
    Qwen2DecoderLayer as OldQwen2DecoderLayer,
    Qwen2ForCausalLM as OldQwen2ForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "Qwen2DecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldQwen2ForCausalLM):
@@ -26,9 +26,7 @@ class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def get_act_for_scaling(module: OldQwen2DecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldQwen2ForCausalLM, device: str):
@@ -39,37 +37,49 @@ class Qwen2AWQForCausalLM(BaseAWQForCausalLM):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers
@@ -79,8 +89,9 @@ class Qwen2Fuser:
        self.model = model

        self.qwen2_blocks: List[Tuple[str, OldQwen2DecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "Qwen2DecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
@@ -93,28 +104,29 @@ class Qwen2Fuser:
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
...
@@ -6,9 +6,10 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.norm import FasterTransformerRMSNorm


class YiAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "YiDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model):
@@ -18,53 +19,63 @@ class YiAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.ln1,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.ln2,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers
@@ -74,10 +85,11 @@ class YiFuser:
        self.model = model

        self.yi_blocks: List[Tuple[str, object]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "YiDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []
@@ -87,30 +99,30 @@ class YiFuser:
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = FasterTransformerRMSNorm(
                module.ln1.weight, module.ln1.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.ln2.weight, module.ln2.variance_epsilon
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
...
import torch.nn as nn


class ScaledActivation(nn.Module):
    def __init__(self, module, scales):
        super().__init__()
        self.act = module
        self.scales = nn.Parameter(scales.data)

    def forward(self, x):
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)
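ScaledActivation divides the activation output by a per-channel scale; AWQ pairs this with multiplying the matching input channels of the following linear by the same scale, so the block's output is unchanged while the rescaled weights become easier to quantize (this is what get_act_for_scaling enables for MPT's ffn.act above). A small numerical check of that equivalence, written as a standalone sketch rather than the library's own scaling routine:

# Standalone equivalence check (toy sizes); not the awq scaling code itself.
import torch
import torch.nn as nn

torch.manual_seed(0)
act, down_proj = nn.GELU(), nn.Linear(8, 4, bias=False)
scales = torch.rand(8) + 0.5                               # per-channel scales
x = torch.randn(2, 8)

ref = down_proj(act(x))                                    # original computation

scaled_down = nn.Linear(8, 4, bias=False)
scaled_down.weight.data = down_proj.weight.data * scales   # scale the linear's input channels
scaled_act_out = act(x) / scales                           # what ScaledActivation computes
out = scaled_down(scaled_act_out)

print(torch.allclose(ref, out, atol=1e-6))                 # True: the rescaling cancels out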
@@ -9,6 +9,7 @@ from awq.utils.fused_utils import get_attention_shapes
try:
    import awq_ft_ext

    FT_INSTALLED = True
except:
    FT_INSTALLED = False
@@ -16,6 +17,7 @@ except:
HF_NEW_CACHE_FORMAT = False

import transformers

# https://github.com/huggingface/transformers/pull/26681 introduced a new cache format
HF_NEW_CACHE_FORMAT = hasattr(transformers, "cache_utils")
if HF_NEW_CACHE_FORMAT:
@@ -25,12 +27,12 @@ if HF_NEW_CACHE_FORMAT:
class RoPE(nn.Module):
    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
        super(RoPE, self).__init__()

        self.freqs_cis = nn.Parameter(
            self.precompute_freqs_cis(
                hidden_size // n_heads, max_seq_len * 2, rope_theta
            ).to(device),
            requires_grad=False,
        )

    @staticmethod
@@ -58,18 +60,21 @@ class RoPE(nn.Module):
        )
        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)

        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
        return xq_out.type_as(xq), xk_out.type_as(xk)


class ALiBi(nn.Module):
    def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
        super(ALiBi, self).__init__()

        # Initialize ALiBi slopes and bias
        slopes, bias = self.build_alibi_bias(
            n_heads, max_seq_len, alibi_bias_max=alibi_bias_max
        )
        self.slopes = nn.Parameter(slopes.float().to(device), requires_grad=False)
        self.bias = nn.Parameter(bias.float().to(device), requires_grad=False)
@@ -79,27 +84,42 @@ class ALiBi(nn.Module):
        m = torch.arange(1, _n_heads + 1, dtype=torch.float32)
        m = m.mul(alibi_bias_max / _n_heads)
        slopes = 1.0 / torch.pow(2, m)

        if _n_heads != n_heads:
            slopes = torch.cat([slopes[1::2], slopes[::2]])[:n_heads]

        return slopes.view(1, n_heads, 1, 1)

    @staticmethod
    def build_alibi_bias(n_heads, seq_len, alibi_bias_max=8, dtype=torch.float32):
        alibi_bias = torch.arange(1 - seq_len, 1, dtype=torch.int32).view(
            1, 1, 1, seq_len
        )
        slopes = ALiBi.gen_slopes(n_heads, alibi_bias_max)
        alibi_bias = alibi_bias * slopes
        slopes = slopes.squeeze(0).squeeze(-1).squeeze(-1)
        return slopes.to(dtype=dtype), alibi_bias.to(dtype=dtype)

    def forward(self, scores, seqlen):
        scores += self.bias[..., :seqlen]
        return scores


class QuantAttentionFused(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        dev,
        max_seq_len=2048,
        use_alibi=False,
        attention_shapes=None,
        rope_theta=10000,
        **kwargs
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_heads = n_heads
@@ -111,17 +131,29 @@ class QuantAttentionFused(nn.Module):
        self.start_pos = 0
        self.use_alibi = use_alibi
        self.cache_batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        if kwargs.get("max_new_tokens") is not None:
            max_seq_len = kwargs["max_new_tokens"]

        self.max_seq_len = max_seq_len
        self.is_hf_transformers = False
        self.rope_theta = rope_theta

        # attention shapes for self attention
        self.attention_shapes = get_attention_shapes(
            attention_shapes,
            max_seq_len,
            self.cache_batch_size,
            n_heads,
            n_kv_heads,
            self.head_dim,
        )
        # cache store that rolls cache
        self.cache = WindowedCache(
            self.attention_shapes["cache_v"],
            self.attention_shapes["cache_k"],
            self.max_seq_len,
            dev,
        )

        if use_alibi:
@@ -133,8 +165,10 @@ class QuantAttentionFused(nn.Module):
            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
            self.rotary_dim = self.head_dim
            self.is_neox = True

    def forward(
        self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs
    ):
        bsz, seqlen, _ = hidden_states.shape

        # Reallocate cache if batch size changes
@@ -147,18 +181,22 @@ class QuantAttentionFused(nn.Module):
                self.cache_batch_size = bsz

            # Always reset to 0
            self.start_pos = 0

        # In case we re-generate, we need to refresh the starting position
        # to 0. We detect it by checking if `past_key_values` is set to None,
        # which indicates that we are on the first step of `generate()`.
        # This is only applicable for `transformers` integration
        if (
            self.is_hf_transformers
            and "past_key_value" in kwargs
            and kwargs["past_key_value"] is None
        ):
            self.start_pos = 0

        xqkv = self.qkv_proj(hidden_states)
        xqkv = xqkv.view((bsz, seqlen) + self.attention_shapes["xqkv_view"])

        xq = self.attention_shapes["xq_slice"](xqkv)
        xk = self.attention_shapes["xk_slice"](xqkv)
        xv = self.attention_shapes["xv_slice"](xqkv)
@@ -179,21 +217,22 @@ class QuantAttentionFused(nn.Module):
                .permute(0, 2, 3, 1, 4)
                .contiguous()
            )

            self.cache.update_kv(values_store, keys_store, bsz, self.start_pos, seqlen)

            # Only necessary to retrieve from cache when we are not processing context
            if seqlen == 1:
                xv, xk = self.cache.get_kv(bsz, self.start_pos, seqlen, self.head_dim)

            keys = xk
            values = xv

            if self.n_kv_groups != 0:
                keys = torch.repeat_interleave(keys, dim=2, repeats=self.n_kv_groups)
                values = torch.repeat_interleave(
                    values, dim=2, repeats=self.n_kv_groups
                )

            xq = xq.transpose(1, 2)
            keys = keys.transpose(1, 2)
            values = values.transpose(1, 2)
@@ -204,7 +243,9 @@ class QuantAttentionFused(nn.Module):
            # When seqlen is 1, there is nothing else to attend to
            if attention_mask is not None and seqlen > 1:
                scores = (
                    scores + attention_mask
                )  # (bs, n_local_heads, slen, cache_len + slen)
            scores = F.softmax(scores.float(), dim=-1).type_as(xq)
            output = torch.matmul(scores, values)  # (bs, n_local_heads, slen, head_dim)
            attention_weight = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
@@ -215,25 +256,25 @@ class QuantAttentionFused(nn.Module):
            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
            attention_weight = awq_ft_ext.single_query_attention(
                xq,  # query
                xk,  # key
                xv,  # value
                self.cache.k,  # key cache
                self.cache.v,  # value cache
                None,  # length per sample
                alibi_slopes,  # alibi slopes
                self.start_pos,  # timestep
                self.rotary_dim,  # rotary embedding dimension
                self.rope_theta,  # rotary embedding base
                self.is_neox,  # is neox
            )
            attention_weight = attention_weight.reshape(bsz, 1, -1)

        attn_output = self.o_proj(attention_weight)
        self.start_pos += seqlen

        # past_key_value is replaced with cache_v, cache_k, returning empty data
        # we pass a dummy past kv cache for transformers to be able to retrieve the correct info
        # about past key length
        past_key_value = [torch.zeros(1, 1, self.start_pos, 1)]
...
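In the non-FT path above, grouped-query attention is handled by repeating each KV head n_kv_groups times (the query-to-KV head ratio) so keys and values line up with the query heads before the matmul. A shape-only sketch of that expansion with toy sizes, no cache or causal mask involved:

# Shape sketch of the repeat_interleave GQA expansion; toy sizes only.
import torch

bsz, seqlen, n_heads, n_kv_heads, head_dim = 1, 5, 8, 2, 16
n_kv_groups = n_heads // n_kv_heads                     # 4 query heads share each KV head

xq = torch.randn(bsz, seqlen, n_heads, head_dim)
xk = torch.randn(bsz, seqlen, n_kv_heads, head_dim)
xv = torch.randn(bsz, seqlen, n_kv_heads, head_dim)

keys = torch.repeat_interleave(xk, dim=2, repeats=n_kv_groups)    # -> (1, 5, 8, 16)
values = torch.repeat_interleave(xv, dim=2, repeats=n_kv_groups)  # -> (1, 5, 8, 16)

scores = torch.matmul(xq.transpose(1, 2), keys.transpose(1, 2).transpose(2, 3))
scores = scores / head_dim**0.5
out = torch.matmul(torch.softmax(scores, dim=-1), values.transpose(1, 2))
print(out.shape)  # torch.Size([1, 8, 5, 16])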
@@ -2,10 +2,21 @@ import os
import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused


class MixtralBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        moe,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta,
    ):
        super().__init__()
        self.n_heads = n_heads
@@ -13,37 +24,62 @@ class MixtralBlock(nn.Module):
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.moe = moe
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = self.moe.forward(self.norm_2(h))
        out = h + out

        return out, None, past_key_value


class LlamaLikeBlock(nn.Module):
    """
    LlamaLikeBlock is intended to be reused across blocks that have
    an architecture that closely resembles Llama, e.g. Mistral and Aquila.
    """

    def __init__(
        self,
        hidden_size,
        n_heads,
        n_kv_heads,
        qkv_layer,
        o_proj,
        mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
        rope_theta=10000,
        use_alibi=False,
    ):
        super().__init__()
        self.n_heads = n_heads
@@ -51,21 +87,33 @@ class LlamaLikeBlock(nn.Module):
        self.hidden_size = hidden_size
        self.norm_1 = norm_1.to(dev)
        self.attn = QuantAttentionFused(
            self.hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=use_alibi,
            rope_theta=rope_theta,
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
        )

        h = hidden_states.to(attn_output.device) + attn_output
@@ -73,23 +121,46 @@ class LlamaLikeBlock(nn.Module):
        return out, None, past_key_value


class MPTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mpt_mlp,
        norm_1,
        norm_2,
        dev,
        max_seq_len,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 0
        self.hidden_size = hidden_size
        self.norm_1 = norm_1
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=True,
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
        self.device = dev

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
@@ -98,16 +169,29 @@ class MPTBlock(nn.Module):
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value


class FalconDecoderLayer(nn.Module):
    def __init__(
        self,
        hidden_size,
        n_heads,
        qkv_layer,
        o_proj,
        mlp,
        dev,
        max_seq_len,
        input_layernorm=None,
        ln_attn=None,
        ln_mlp=None,
        new_decoder_arch=True,
    ):
        super().__init__()
        self.n_heads = n_heads
        self.n_kv_heads = 8 if new_decoder_arch else 0
@@ -117,33 +201,52 @@ class FalconDecoderLayer(nn.Module):
        if new_decoder_arch:
            attention_shapes = None
        else:
            attention_shapes = self._get_attention_shapes(
                n_heads, max_seq_len, self.hidden_size // n_heads
            )

        # TODO: Falcon has ALiBi implemented but which model uses it?
        self.attn = QuantAttentionFused(
            hidden_size,
            self.n_heads,
            self.n_kv_heads,
            qkv_layer,
            o_proj,
            dev=dev,
            max_seq_len=max_seq_len,
            use_alibi=False,
            attention_shapes=attention_shapes,
        ).to(dev)

        if new_decoder_arch:
            self.ln_attn = ln_attn  # before attention
            self.ln_mlp = ln_mlp  # before mlp
        else:
            self.input_layernorm = input_layernorm  # before attention

        self.mlp = mlp
        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))

        self.attention_shapes = {
            # following fastertransformer definition
            "cache_v": (
                batch_size,
                1,
                max_seq_len,
                head_dim,
            ),
            # 8: pack 8 fp16 in FT, if fp32 then use 4
            "cache_k": (
                batch_size,
                1,
                head_dim // 8,
                max_seq_len,
                8,
            ),
            "xqkv_view": (n_heads + 2, head_dim),
            "xq_slice": lambda xqkv: xqkv[:, :, :-2],
            "xk_slice": lambda xqkv: xqkv[:, :, [-2]],
            "xv_slice": lambda xqkv: xqkv[:, :, [-1]],
@@ -153,27 +256,32 @@ class FalconDecoderLayer(nn.Module):
            "xk_reshape": (1, head_dim // 8, 8),
            "single_xq_view": (n_heads, head_dim),
            "single_xk_view": (1, head_dim),
            "single_xv_view": (1, head_dim),
        }

        return self.attention_shapes

    def forward(
        self,
        hidden_states,
        past_key_value,
        attn_bias=None,
        attention_mask=None,
        is_causal=None,
    ):
        if self.new_decoder_arch:
            layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            layernorm_out = self.input_layernorm(hidden_states)

        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=layernorm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True,
        )

        h_attn = hidden_states.to(attn_output.device) + attn_output
@@ -182,7 +290,7 @@ class FalconDecoderLayer(nn.Module):
            h_mlp = self.mlp.forward(mlp_layernorm_out)
        else:
            h_mlp = self.mlp.forward(layernorm_out)

        out = h_attn + h_mlp

        return out, None, past_key_value
import torch import torch
class WindowedCache: class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device): def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
""" """
The window size is the same as the max_new_tokens. The window will The window size is the same as the max_seq_len. The window will
automatically roll once max_new_tokens is exceeded. automatically roll once max_seq_len is exceeded.
""" """
# [batch_size, n_kv_heads, max_seq_len, head_dim] # [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half() self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor] # [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half() self.k = torch.zeros(cache_k_shape).to(device).half()
self.max_seq_len = max_seq_len self.max_seq_len = max_seq_len
def get_kv(self, batch_size, start_pos, seqlen, head_dim): def get_kv(self, batch_size, start_pos, seqlen, head_dim):
""" """
Gets the key-value store in correct shapes. Gets the key-value store in correct shapes.
""" """
xv = self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous() xv = (
xk = self.k[:batch_size, :, :, : start_pos + seqlen, :].transpose(2, 3).contiguous() self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
)
xk = (
self.k[:batch_size, :, :, : start_pos + seqlen, :]
.transpose(2, 3)
.contiguous()
)
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous() xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
return xv, xk return xv, xk
def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen): def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
""" """
Updates the values in the key-value store. Updates the values in the key-value store.
...@@ -41,19 +48,23 @@ class WindowedCache: ...@@ -41,19 +48,23 @@ class WindowedCache:
# Zero out the new part # Zero out the new part
self.v[:, :, -n:, :] = 0 self.v[:, :, -n:, :] = 0
self.k[:, :, :, -n:, :] = 0 self.k[:, :, :, -n:, :] = 0
return start_pos - n return start_pos - n
def to(self, device): def to(self, device):
self.k = self.k.to(device) self.k = self.k.to(device)
self.v = self.v.to(device) self.v = self.v.to(device)
def increase_batch_size(self, to_bsz): def increase_batch_size(self, to_bsz):
"""Dynamically allocate new kv when batch size changes.""" """Dynamically allocate new kv when batch size changes."""
self.v = torch.zeros(to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device) self.v = torch.zeros(
self.k = torch.zeros(to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device) to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
)
self.k = torch.zeros(
to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
)
def decrease_batch_size(self, to_bsz): def decrease_batch_size(self, to_bsz):
"""Dynamically remove part of cache if batch size changes.""" """Dynamically remove part of cache if batch size changes."""
self.v = self.v[:to_bsz, :, :, :] self.v = self.v[:to_bsz, :, :, :]
self.k = self.k[:to_bsz, :, :, :, :] self.k = self.k[:to_bsz, :, :, :, :]
\ No newline at end of file
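A minimal usage sketch of the cache above, assuming the FasterTransformer-style shapes built by `_get_attention_shapes` (all sizes illustrative):

import torch

batch_size, n_kv_heads, max_seq_len, head_dim = 1, 1, 2048, 128
cache_v_shape = (batch_size, n_kv_heads, max_seq_len, head_dim)
cache_k_shape = (batch_size, n_kv_heads, head_dim // 8, max_seq_len, 8)  # 8 fp16 values packed per slot

cache = WindowedCache(cache_v_shape, cache_k_shape, max_seq_len, device="cpu")
xv, xk = cache.get_kv(batch_size=batch_size, start_pos=0, seqlen=16, head_dim=head_dim)
# Both tensors come back as [batch_size, seqlen, n_kv_heads, head_dim], i.e. [1, 16, 1, 128] here.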
...@@ -5,26 +5,28 @@ from awq.modules.linear.gemv import WQLinear_GEMV ...@@ -5,26 +5,28 @@ from awq.modules.linear.gemv import WQLinear_GEMV
try: try:
import awq_ext # with CUDA kernels import awq_ext # with CUDA kernels
AWQ_INSTALLED = True AWQ_INSTALLED = True
except: except:
AWQ_INSTALLED = False AWQ_INSTALLED = False
class QuantFusedMLP(nn.Module): class QuantFusedMLP(nn.Module):
def __init__( def __init__(
self, self,
gate_proj, gate_proj,
down_proj, down_proj,
up_proj, up_proj,
activation = F.silu, activation=F.silu,
): ):
super().__init__() super().__init__()
self.register_buffer('gate_proj_qweight', gate_proj.qweight) self.register_buffer("gate_proj_qweight", gate_proj.qweight)
self.register_buffer('gate_proj_scales', gate_proj.scales) self.register_buffer("gate_proj_scales", gate_proj.scales)
self.register_buffer('gate_proj_qzeros', gate_proj.qzeros) self.register_buffer("gate_proj_qzeros", gate_proj.qzeros)
self.register_buffer('up_proj_qweight', up_proj.qweight) self.register_buffer("up_proj_qweight", up_proj.qweight)
self.register_buffer('up_proj_scales', up_proj.scales) self.register_buffer("up_proj_scales", up_proj.scales)
self.register_buffer('up_proj_qzeros', up_proj.qzeros) self.register_buffer("up_proj_qzeros", up_proj.qzeros)
self.in_features = gate_proj.in_features self.in_features = gate_proj.in_features
self.intermediate_size = gate_proj.out_features self.intermediate_size = gate_proj.out_features
...@@ -66,17 +68,13 @@ class QuantFusedMLP(nn.Module): ...@@ -66,17 +68,13 @@ class QuantFusedMLP(nn.Module):
x = routing_weights * x x = routing_weights * x
return x return x
class QuantLlamaMLP(QuantFusedMLP): class QuantLlamaMLP(QuantFusedMLP):
r""" r"""
QuantLlamaMLP class kept for backward compatibility; in the future, users QuantLlamaMLP class kept for backward compatibility; in the future, users
should always use the `QuantFusedMLP` class instead. should always use the `QuantFusedMLP` class instead.
""" """
def __init__(
self, def __init__(self, gate_proj, down_proj, up_proj):
gate_proj, super().__init__(gate_proj, down_proj, up_proj)
down_proj,
up_proj
):
super().__init__(gate_proj, down_proj, up_proj)
\ No newline at end of file
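For reference, the fused module above follows the usual gate/up/down (SwiGLU-style) MLP pattern, only with quantized GEMMs; a dense sketch of the same dataflow, assuming plain `nn.Linear` projections, looks like this:

import torch.nn as nn
import torch.nn.functional as F

def dense_gated_mlp(x, gate_proj: nn.Linear, up_proj: nn.Linear, down_proj: nn.Linear, activation=F.silu):
    # activation(gate(x)) * up(x), projected back down to the hidden size.
    # The quantized module additionally accepts routing weights (see the forward above) for MoE use.
    return down_proj(activation(gate_proj(x)) * up_proj(x))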
...@@ -2,8 +2,16 @@ import torch ...@@ -2,8 +2,16 @@ import torch
import torch.nn as nn import torch.nn as nn
from typing import List from typing import List
from awq.utils import fused_utils from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast from transformers.modeling_outputs import (
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock BaseModelOutputWithPast,
MoeModelOutputWithPast,
)
from awq.modules.fused.block import (
MPTBlock,
FalconDecoderLayer,
LlamaLikeBlock,
MixtralBlock,
)
class MixtralModel(nn.Module): class MixtralModel(nn.Module):
...@@ -47,8 +55,10 @@ class MixtralModel(nn.Module): ...@@ -47,8 +55,10 @@ class MixtralModel(nn.Module):
h, h,
mask, mask,
) )
h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal) h, _, past_key_value = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm(h) h = self.norm(h)
return MoeModelOutputWithPast( return MoeModelOutputWithPast(
...@@ -65,6 +75,7 @@ class LlamaLikeModel(nn.Module): ...@@ -65,6 +75,7 @@ class LlamaLikeModel(nn.Module):
LlamaLikeModel is intended to be reused across models that have LlamaLikeModel is intended to be reused across models that have
an architecture that closely resembles Llama, e.g. Mistral and Aquila. an architecture that closely resembles Llama, e.g. Mistral and Aquila.
""" """
def __init__(self, vocab_size, blocks, embedding, norm): def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__() super().__init__()
self.vocab_size = vocab_size self.vocab_size = vocab_size
...@@ -72,12 +83,19 @@ class LlamaLikeModel(nn.Module): ...@@ -72,12 +83,19 @@ class LlamaLikeModel(nn.Module):
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks) self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm self.norm = norm
self.last_forward_num_tokens = 0 self.last_forward_num_tokens = 0
@torch.inference_mode() @torch.inference_mode()
def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, input_ids, self.last_forward_num_tokens
self.last_forward_num_tokens
) )
_bsz, seqlen = input_ids.shape _bsz, seqlen = input_ids.shape
...@@ -89,7 +107,7 @@ class LlamaLikeModel(nn.Module): ...@@ -89,7 +107,7 @@ class LlamaLikeModel(nn.Module):
seqlen=seqlen, seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos, start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device, device=input_ids.device,
type_as=h type_as=h,
) )
for layer in self.blocks: for layer in self.blocks:
...@@ -99,14 +117,17 @@ class LlamaLikeModel(nn.Module): ...@@ -99,14 +117,17 @@ class LlamaLikeModel(nn.Module):
mask, mask,
) )
h, _, past_key_value = layer( h, _, past_key_value = layer(
h, h, None, attention_mask=mask, is_causal=is_causal
None,
attention_mask=mask,
is_causal=is_causal
) )
h = self.norm(h) h = self.norm(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
class MPTModel(nn.Module): class MPTModel(nn.Module):
def __init__(self, vocab_size, blocks, wte, norm_f): def __init__(self, vocab_size, blocks, wte, norm_f):
...@@ -120,10 +141,17 @@ class MPTModel(nn.Module): ...@@ -120,10 +141,17 @@ class MPTModel(nn.Module):
self.last_forward_num_tokens = 0 self.last_forward_num_tokens = 0
@torch.inference_mode() @torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, input_ids, self.last_forward_num_tokens
self.last_forward_num_tokens
) )
_bsz, seqlen = input_ids.shape _bsz, seqlen = input_ids.shape
...@@ -135,7 +163,7 @@ class MPTModel(nn.Module): ...@@ -135,7 +163,7 @@ class MPTModel(nn.Module):
seqlen=seqlen, seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos, start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device, device=input_ids.device,
type_as=h type_as=h,
) )
for layer in self.blocks: for layer in self.blocks:
...@@ -145,14 +173,17 @@ class MPTModel(nn.Module): ...@@ -145,14 +173,17 @@ class MPTModel(nn.Module):
mask, mask,
) )
h, _, past_key_value = layer( h, _, past_key_value = layer(
h, h, None, attention_mask=mask, is_causal=is_causal
None,
attention_mask=mask,
is_causal=is_causal
) )
h = self.norm_f(h) h = self.norm_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
class FalconModel(nn.Module): class FalconModel(nn.Module):
def __init__(self, vocab_size, blocks, word_embeddings, ln_f): def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
...@@ -166,10 +197,17 @@ class FalconModel(nn.Module): ...@@ -166,10 +197,17 @@ class FalconModel(nn.Module):
self.last_forward_num_tokens = 0 self.last_forward_num_tokens = 0
@torch.inference_mode() @torch.inference_mode()
def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs): def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids( input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, input_ids, self.last_forward_num_tokens
self.last_forward_num_tokens
) )
_bsz, seqlen = input_ids.shape _bsz, seqlen = input_ids.shape
...@@ -181,7 +219,7 @@ class FalconModel(nn.Module): ...@@ -181,7 +219,7 @@ class FalconModel(nn.Module):
seqlen=seqlen, seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos, start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device, device=input_ids.device,
type_as=h type_as=h,
) )
for layer in self.blocks: for layer in self.blocks:
...@@ -191,11 +229,13 @@ class FalconModel(nn.Module): ...@@ -191,11 +229,13 @@ class FalconModel(nn.Module):
mask, mask,
) )
h, _, past_key_value = layer( h, _, past_key_value = layer(
h, h, None, attention_mask=mask, is_causal=is_causal
None,
attention_mask=mask,
is_causal=is_causal
) )
h = self.ln_f(h) h = self.ln_f(h)
return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=()) return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
import torch
import triton
from typing import Dict
import triton.language as tl
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class FusedSparseMoeBlock(torch.nn.Module):
def __init__(
self,
top_k,
gate,
ws,
w2s,
):
super().__init__()
self.gate = gate
self.top_k = top_k
self.ws = ws
self.w2s = w2s
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = apply_moe_weights(
self.ws,
self.w2s,
hidden_states,
router_logits,
self.top_k,
renormalize=True,
)
return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
def apply_moe_weights(
w1: Dict[str, torch.Tensor],
w2: Dict[str, torch.Tensor],
x: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 1024
if FP16_MATMUL_HEURISTIC_CONDITION:
dequant_w1 = awq_ext.dequantize_weights_cuda(
w1.qweight, w1.scales, w1.qzeros, 0, 0, 0, False
).permute(0, 2, 1)
dequant_w2 = awq_ext.dequantize_weights_cuda(
w2.qweight, w2.scales, w2.qzeros, 0, 0, 0, False
).permute(0, 2, 1)
return fused_moe(x, dequant_w1, dequant_w2, gating_output, topk, renormalize)
topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
(sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size(
topk_ids, 16, w1.qweight.shape[0]
)
x = x.view(x.shape[0], 1, *x.shape[1:])
gate_up = awq_ext.grouped_gemm_forward(
x,
w1.qweight,
w1.scales,
w1.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
8,
)
out = torch.empty(
(gate_up.shape[:-1] + (gate_up.shape[-1] // 2,)), dtype=x.dtype, device=x.device
)
awq_ext.silu_and_mul(out, gate_up)
out = awq_ext.grouped_gemm_forward(
out,
w2.qweight,
w2.scales,
w2.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
8,
)
return torch.sum(out, dim=1)
@triton.jit
def fused_moe_kernel(
# Pointers to matrices
a_ptr,
b_ptr,
c_ptr,
topk_weights_ptr,
sorted_token_ids_ptr,
expert_ids_ptr,
num_tokens_post_padded_ptr,
# Matrix dimensions
N,
K,
EM,
num_valid_tokens,
# The stride variables represent how much to increase the ptr by when moving by 1
# element in a particular dimension. E.g. `stride_am` is how much to increase `a_ptr`
# by to get the element one row down (A has M rows).
stride_am,
stride_ak,
stride_be,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
# Meta-parameters
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
MUL_ROUTED_WEIGHT: tl.constexpr,
top_k: tl.constexpr,
compute_type: tl.constexpr,
):
"""
Implements the fused computation for a Mixture of Experts (MOE) using token and expert matrices.
Key Parameters:
- A: The input tensor representing tokens with shape (*, K), where '*' can be any shape representing batches and K is the feature dimension of each token.
- B: The stacked MOE weight tensor with shape (E, N, K), where E is the number of experts, K is the input feature dimension, and N is the output feature dimension.
- C: The output cache tensor with shape (M, topk, N), where M is the total number of tokens post padding, topk is the number of times each token is repeated,
and N is the output feature dimension.
- sorted_token_ids: A tensor containing the sorted indices of tokens, repeated topk times and arranged by the expert index they are assigned to.
- expert_ids: A tensor containing the indices of the expert for each block. It determines which expert matrix from B should be used for each block in A.
This kernel performs the multiplication of a token by its corresponding expert matrix as determined by `expert_ids`. The sorting of `sorted_token_ids`
by expert index and padding ensures divisibility by BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix multiplication across different blocks processed by the same expert.
"""
# -----------------------------------------------------------
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
# We will advance this pointer as we move in the K direction
# and accumulate
# `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
# `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return
offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
# -----------------------------------------------------------
# Iterate to compute a block of the C matrix.
# We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
# of fp32 values for higher accuracy.
# `accumulator` will be converted back to fp16 after the loop.
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the K dimension.
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
accumulator += tl.dot(a, b)
# Advance the ptrs to the next K block.
a_ptrs += BLOCK_SIZE_K * stride_ak
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
accumulator = accumulator.to(compute_type)
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
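The kernel's indexing is easier to follow against a dense PyTorch reference; the sketch below is illustrative only and mirrors the layout described in the docstring (A indexed by token // top_k, B by the token's expert, C flattened to one row per topk_ids entry):

def fused_moe_kernel_reference(A, B, topk_weights, topk_ids, mul_routed_weight, top_k):
    # A: (num_a_rows, K), B: (E, N, K); one output row per entry of topk_ids.
    num_valid_tokens = topk_ids.numel()
    E, N, K = B.shape
    C = torch.zeros(num_valid_tokens, N, dtype=A.dtype, device=A.device)
    for token in range(num_valid_tokens):
        expert = topk_ids.view(-1)[token]
        acc = A[token // top_k].float() @ B[expert].float().T  # (K,) @ (K, N) -> (N,)
        if mul_routed_weight:
            acc = acc * topk_weights.view(-1)[token]
        C[token] = acc.to(A.dtype)
    return C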
def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int):
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append one padding token (id 12) per expert, giving [12, 12, 12, 12] in total.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Token id 12 is a non-existent (padding) token and is ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1),),
dtype=torch.int32,
device=topk_ids.device,
)
expert_ids = torch.empty(
(topk_ids.numel() + num_experts,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
awq_ext.moe_alig_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
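Since the alignment itself happens inside the awq_ext CUDA kernel, a slow pure-Python sketch (assuming 0-based expert ids; illustrative only) helps make the padding and sorting concrete:

def moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    pad_id = topk_ids.numel()  # padding entries point one past the last real token
    flat = topk_ids.flatten().tolist()
    sorted_ids, expert_ids = [], []
    for expert in range(num_experts):
        positions = [i for i, e in enumerate(flat) if e == expert]
        while positions and len(positions) % block_size != 0:
            positions.append(pad_id)
        sorted_ids.extend(positions)
        expert_ids.extend([expert] * (len(positions) // block_size))
    return sorted_ids, expert_ids, len(sorted_ids)

# With topk_ids = torch.tensor([[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]]) and block_size = 4,
# this returns sorted_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12],
# expert_ids [0, 1, 2, 3] and 16 tokens after padding, mirroring the example in the docstring.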
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: dict,
):
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
grid = lambda META: (
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid](
A,
B,
C,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
B.shape[2],
sorted_token_ids.shape[0],
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
C.stride(1),
C.stride(2),
MUL_ROUTED_WEIGHT=mul_routed_weight,
top_k=top_k,
compute_type=tl.bfloat16 if A.dtype == torch.bfloat16 else tl.float16,
**config,
)
def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
"""Compute top-k indice and weights from gating logits
Args:
gating_output (torch.Tensor): The output of the gating operation (before softmax).
topk (int): The number of top-k experts to select.
renormalize (bool): If True, renormalize the top-k weights to sum to 1.
"""
M = gating_output.shape[0]
if torch.version.hip is not None:
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=gating_output.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=gating_output.device
)
awq_ext.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
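A quick usage example (a sketch, assuming a CUDA device with awq_ext available):

gating_output = torch.randn(4, 8, dtype=torch.float16, device="cuda")  # 4 tokens, 8 experts
topk_weights, topk_ids = fused_topk(gating_output, topk=2, renormalize=True)
# topk_weights: (4, 2) fp32 routing weights summing to 1 per row; topk_ids: (4, 2) expert indices.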
def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
inplace: bool = True,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of weights, w1 and w2, and top-k gating mechanism.
Parameters:
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation (before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- inplace (bool): If True, write the result in-place into hidden_states. Defaults to True.
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
# assert w1.is_contiguous(), "Expert weights1 must be contiguous"
# assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
config = {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if topk_ids.numel() <= w1.shape[0]:
config = {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config["BLOCK_SIZE_M"], E
)
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
)
awq_ext.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
)
if inplace:
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
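For validation, a dense per-token reference of the same layer can be written directly in PyTorch. This is only a sketch: it assumes already-dequantized fp16/fp32 expert weights and that the gate half occupies the first half of w1 (the convention implied by silu_and_mul above).

def fused_moe_reference(hidden_states, w1, w2, gating_output, topk, renormalize):
    # hidden_states: (M, K); w1: (E, 2*I, K) holding gate and up halves; w2: (E, K, I).
    routing = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = torch.topk(routing, topk, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden_states, dtype=torch.float32)
    for m in range(hidden_states.shape[0]):
        for weight, expert in zip(topk_weights[m], topk_ids[m]):
            gate_up = hidden_states[m].float() @ w1[expert].float().T  # (2*I,)
            half = gate_up.shape[0] // 2
            inter = torch.nn.functional.silu(gate_up[:half]) * gate_up[half:]
            out[m] += weight * (inter @ w2[expert].float().T)  # (K,)
    return out.to(hidden_states.dtype)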
...@@ -2,4 +2,4 @@ from .exllama import WQLinear_Exllama ...@@ -2,4 +2,4 @@ from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2 from .exllamav2 import WQLinear_ExllamaV2
from .gemm import WQLinear_GEMM from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin from .marlin import WQLinear_Marlin
\ No newline at end of file
...@@ -11,6 +11,7 @@ try: ...@@ -11,6 +11,7 @@ try:
except: except:
AWQ_INSTALLED = False AWQ_INSTALLED = False
# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev # Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function): class WQLinearMMFunction(Function):
@staticmethod @staticmethod
...@@ -24,45 +25,29 @@ class WQLinearMMFunction(Function): ...@@ -24,45 +25,29 @@ class WQLinearMMFunction(Function):
w_bit=4, w_bit=4,
group_size=128, group_size=128,
bias=None, bias=None,
out_features=0 out_features=0,
): ):
# The forward pass can use ctx. # The forward pass can use ctx.
ctx.save_for_backward(x, qweight, qzeros, scales, bias) ctx.save_for_backward(x, qweight, qzeros, scales, bias)
ctx.out_features = out_features ctx.out_features = out_features
out_shape = x.shape[:-1] + (out_features, ) out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16) x = x.to(torch.float16)
if AWQ_INSTALLED: if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024 FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
if FP16_MATMUL_HEURISTIC_CONDITION: if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda( out = awq_ext.dequantize_weights_cuda(
qweight, qweight, scales, qzeros, 0, 0, 0, False
scales,
qzeros,
0,
0,
0,
False
) )
out = torch.matmul(x, out) out = torch.matmul(x, out)
else: else:
out = awq_ext.gemm_forward_cuda( out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
qweight,
scales,
qzeros,
8
) )
else: else:
out = dequantize_gemm( out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
qweight,
qzeros,
scales,
w_bit,
group_size
)
out = torch.matmul(x, out) out = torch.matmul(x, out)
out = out + bias if bias is not None else out out = out + bias if bias is not None else out
...@@ -71,7 +56,7 @@ class WQLinearMMFunction(Function): ...@@ -71,7 +56,7 @@ class WQLinearMMFunction(Function):
# always want 3D tensor if tensor is 2D # always want 3D tensor if tensor is 2D
if len(out.shape) == 2: if len(out.shape) == 2:
out = out.unsqueeze(0) out = out.unsqueeze(0)
return out return out
@staticmethod @staticmethod
...@@ -79,13 +64,7 @@ class WQLinearMMFunction(Function): ...@@ -79,13 +64,7 @@ class WQLinearMMFunction(Function):
input, qweight, qzeros, scales, bias = ctx.saved_tensors input, qweight, qzeros, scales, bias = ctx.saved_tensors
weights = awq_ext.dequantize_weights_cuda( weights = awq_ext.dequantize_weights_cuda(
qweight, qweight, scales, qzeros, 1, 0, 0, False
scales,
qzeros,
1,
0,
0,
False
) )
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
...@@ -98,7 +77,9 @@ class WQLinearMMFunction(Function): ...@@ -98,7 +77,9 @@ class WQLinearMMFunction(Function):
class WQLinear_GEMM(nn.Module): class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev, training=False): def __init__(
self, w_bit, group_size, in_features, out_features, bias, dev, training=False
):
super().__init__() super().__init__()
if w_bit not in [4]: if w_bit not in [4]:
......
...@@ -54,7 +54,7 @@ class WQLinear_Marlin(nn.Module): ...@@ -54,7 +54,7 @@ class WQLinear_Marlin(nn.Module):
self.in_features = in_features self.in_features = in_features
self.out_features = out_features self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs self.max_par = 8 # partitioning for large inputs
# quick sanity check (make sure of alignment) # quick sanity check (make sure of alignment)
assert self.in_features % self.group_size == 0 assert self.in_features % self.group_size == 0
......
...@@ -117,7 +117,7 @@ class AwqQuantizer: ...@@ -117,7 +117,7 @@ class AwqQuantizer:
best_device = "cuda:" + str(i % torch.cuda.device_count()) best_device = "cuda:" + str(i % torch.cuda.device_count())
else: else:
best_device = get_best_device() best_device = get_best_device()
self.modules[i] = self.modules[i].to(best_device) self.modules[i] = self.modules[i].to(best_device)
common_device = next(self.modules[i].parameters()).device common_device = next(self.modules[i].parameters()).device
...@@ -190,15 +190,15 @@ class AwqQuantizer: ...@@ -190,15 +190,15 @@ class AwqQuantizer:
linear_layer.weight.data linear_layer.weight.data
) )
if self.version == "GEMM": if self.version == "gemm":
scales = scales.t().contiguous() scales = scales.t().contiguous()
zeros = zeros.t().contiguous() zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM q_linear_module = WQLinear_GEMM
elif self.version == "GEMV": elif self.version == "gemv":
q_linear_module = WQLinear_GEMV q_linear_module = WQLinear_GEMV
elif self.version == "Marlin": elif self.version == "marlin":
q_linear_module = WQLinear_Marlin q_linear_module = WQLinear_Marlin
else: else:
...@@ -355,7 +355,9 @@ class AwqQuantizer: ...@@ -355,7 +355,9 @@ class AwqQuantizer:
continue continue
named_linears[name].to(get_best_device()) named_linears[name].to(get_best_device())
max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name]) max_val = self._compute_best_clip(
named_linears[name].weight, input_feat[name]
)
clip_list.append((name, max_val)) clip_list.append((name, max_val))
named_linears[name].cpu() named_linears[name].cpu()
...@@ -481,7 +483,9 @@ class AwqQuantizer: ...@@ -481,7 +483,9 @@ class AwqQuantizer:
clear_memory() clear_memory()
if layer_kwargs.get("attention_mask") is not None: if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(best_device) layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
best_device
)
return modules, layer_kwargs, inps return modules, layer_kwargs, inps
......
...@@ -3,33 +3,41 @@ import logging ...@@ -3,33 +3,41 @@ import logging
from typing import List, Union from typing import List, Union
from datasets import load_dataset from datasets import load_dataset
def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None, n_samples=512, block_size=512, def get_calib_dataset(
split="train", text_column="text"): data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None,
n_samples=512,
block_size=512,
split="train",
text_column="text",
):
if isinstance(data, str): if isinstance(data, str):
if data == "pileval": if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation") dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else: else:
dataset = load_dataset(data, split=split) dataset = load_dataset(data, split=split)
dataset = dataset.shuffle(seed=42) dataset = dataset.shuffle(seed=42)
elif isinstance(data, list): elif isinstance(data, list):
if isinstance(data[0], str): if isinstance(data[0], str):
dataset = [{text_column: text} for text in data] dataset = [{text_column: text} for text in data]
elif isinstance(data[0][0], int): elif isinstance(data[0][0], int):
dataset = data dataset = data
else: else:
raise NotImplementedError( raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list" "Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element" "that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words.") " or a list of list of int for tokenized words."
)
else: else:
raise NotImplementedError( raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list" "Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element" "that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words.") " or a list of list of int for tokenized words."
)
samples = [] samples = []
n_run = 0 n_run = 0
for data in dataset: for data in dataset:
...@@ -52,4 +60,6 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval", ...@@ -52,4 +60,6 @@ def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
cat_samples = torch.cat(samples, dim=1) cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // block_size n_split = cat_samples.shape[1] // block_size
logging.debug(f" * Split into {n_split} blocks") logging.debug(f" * Split into {n_split} blocks")
return [cat_samples[:, i*block_size:(i+1)*block_size] for i in range(n_split)] return [
cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
]
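Example usage (a sketch; the tokenizer checkpoint is illustrative):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")
samples = get_calib_dataset(data="pileval", tokenizer=tokenizer, n_samples=512, block_size=512)
# `samples` is a list of [1, block_size] token-id tensors ready for calibration forward passes.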