Unverified Commit 2fcbf268 authored by Ilyas Moutawwakil, committed by GitHub

Exllama kernels support (#313)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent a3db8099
@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldAquilaDecoderLayer,
    LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class AquilaAWQForCausalLM(BaseAWQForCausalLM):

@@ -95,11 +94,6 @@ class AquilaFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon

@@ -114,7 +108,7 @@ class AquilaFuser:
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,

@@ -127,3 +121,5 @@ class AquilaFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
@@ -23,39 +23,76 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
    "llava": LlavaAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
    config = AutoConfig.from_pretrained(
        model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
    )
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
        raise TypeError(f"{config.model_type} isn't supported yet.")
    model_type = config.model_type
    return model_type

class AutoAWQForCausalLM:
    def __init__(self):
        raise EnvironmentError(
            "You must instantiate AutoAWQForCausalLM with\n"
            "AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained"
        )

    @classmethod
    def from_pretrained(
        self,
        model_path,
        trust_remote_code=True,
        safetensors=False,
        device_map=None,
        **model_init_kwargs,
    ) -> BaseAWQForCausalLM:
        model_type = check_and_get_model_type(
            model_path, trust_remote_code, **model_init_kwargs
        )

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
            model_path,
            model_type,
            trust_remote_code=trust_remote_code,
            safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

    @classmethod
    def from_quantized(
        self,
        quant_path,
        quant_filename="",
        max_new_tokens=None,
        trust_remote_code=True,
        fuse_layers=True,
        use_exllama=False,
        use_exllama_v2=False,
        batch_size=1,
        safetensors=True,
        device_map="balanced",
        offload_folder=None,
        **config_kwargs,
    ) -> BaseAWQForCausalLM:
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path,
            model_type,
            quant_filename,
            max_new_tokens,
            trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
            safetensors=safetensors,
            device_map=device_map,
            offload_folder=offload_folder,
            **config_kwargs,
        )
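Editor's note (not part of the diff): a minimal sketch of how the two new flags surface on the auto class. The checkpoint path is a placeholder reused from the example at the bottom of this commit, and in practice only one of the three loads would be kept; if both flags are set, use_exllama wins because it is checked first downstream.

from awq import AutoAWQForCausalLM

quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"  # placeholder AWQ (GEMM) checkpoint

# Default behaviour: GEMM/GEMV kernels, as selected by the checkpoint's quant config.
model = AutoAWQForCausalLM.from_quantized(quant_path)

# Opt into the Exllama kernels; the flag is forwarded to BaseAWQForCausalLM.from_quantized.
model = AutoAWQForCausalLM.from_quantized(quant_path, use_exllama=True)

# Or the ExllamaV2 kernels.
model = AutoAWQForCausalLM.from_quantized(quant_path, use_exllama_v2=True)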
@@ -5,9 +5,7 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class BaichuanAWQForCausalLM(BaseAWQForCausalLM):

@@ -102,11 +100,6 @@ class BaichuanFuser:
            #     module.self_attn.v_proj
            # )
            qkv = module.self_attn.W_pack
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon

@@ -121,7 +114,7 @@ class BaichuanFuser:
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,

@@ -135,3 +128,5 @@ class BaichuanFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
import os
import gc
import json
import time
import torch
import torch.nn as nn

@@ -14,6 +14,8 @@ import transformers
from transformers.modeling_utils import shard_checkpoint

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
    get_named_linears,
    set_op_by_name,

@@ -34,6 +36,8 @@ from accelerate.big_modeling import (
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name

@@ -59,12 +63,15 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
    "llava": "AutoModelForVision2Seq",
}

class BaseAWQForCausalLM(nn.Module):
    def __init__(
        self, model, model_type, is_quantized, config, quant_config, processor
    ):
        super().__init__()
        self.model: PreTrainedModel = model
        self.model_type: str = model_type
        self.is_quantized: bool = is_quantized
        self.search_result = None
        self.config: PretrainedConfig = config
        self.quant_config: AwqConfig = quant_config

@@ -81,16 +88,32 @@ class BaseAWQForCausalLM(nn.Module):
        return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(
        self,
        tokenizer=None,
        quant_config={},
        calib_data: Union[str, List[str]] = "pileval",
        split="train",
        text_column="text",
        duo_scaling=True,
        modules_to_not_convert=None,
        export_compatible=False,
    ):
        self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)

        self.quantizer = AwqQuantizer(
            self,
            self.model,
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.version,
            calib_data,
            split,
            text_column,
            duo_scaling,
            modules_to_not_convert=modules_to_not_convert,
            export_compatible=export_compatible,
        )
        self.quantizer.quantize()

@@ -118,12 +141,15 @@ class BaseAWQForCausalLM(nn.Module):
        pass

    def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()

@@ -135,42 +161,51 @@ class BaseAWQForCausalLM(nn.Module):
            self.processor.save_pretrained(save_dir)

        # Remove empty state dict
        default_paths = [
            f"{save_dir}/model.safetensors",
            f"{save_dir}/pytorch_model.bin",
        ]
        for path in default_paths:
            if os.path.exists(path):
                os.remove(path)

        # model_name has no extension, add it when saving state_dict
        model_name = "model.safetensors" if safetensors else "pytorch_model.bin"

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors must be in the same memory, so we duplicate and use contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(
                    shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
                )
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
                file.write(json.dumps(index, indent=4))

    @classmethod
    def from_pretrained(
        self,
        model_path,
        model_type,
        torch_dtype: torch.dtype = torch.float16,
        trust_remote_code=True,
        safetensors=False,
        device_map=None,
        **model_init_kwargs,
    ):
        # Get weights path and quant config
        model_weights_path, config, quant_config = self._load_config(
            self, model_path, "", safetensors, trust_remote_code=trust_remote_code
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]

@@ -188,26 +223,49 @@ class BaseAWQForCausalLM(nn.Module):
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
            device_map=device_map,
            **model_init_kwargs,
        )

        model.eval()

        return self(
            model,
            model_type,
            is_quantized=False,
            config=config,
            quant_config=quant_config,
            processor=processor,
        )

    @classmethod
    def from_quantized(
        self,
        model_path,
        model_type,
        model_filename="",
        max_new_tokens=None,
        torch_dtype=torch.float16,
        trust_remote_code=True,
        safetensors=True,
        is_quantized=True,
        fuse_layers=False,
        use_exllama=False,
        use_exllama_v2=False,
        version="GEMM",
        device_map="balanced",
        offload_folder=None,
        **config_kwargs,
    ):
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
            self,
            model_path,
            model_filename,
            safetensors,
            version,
            trust_remote_code,
            max_new_tokens=max_new_tokens,
            **config_kwargs,
        )

        target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]

@@ -215,10 +273,21 @@ class BaseAWQForCausalLM(nn.Module):
        # [STEP 3] Load model
        with init_empty_weights():
            model = target_cls.from_config(
                config=config,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # Prepare WQLinear layers, replace nn.Linear
        self._load_quantized_modules(
            self,
            model,
            quant_config,
            quant_config.version,
            use_exllama=use_exllama,
            use_exllama_v2=use_exllama_v2,
        )

        model.tie_weights()

@@ -237,12 +306,37 @@ class BaseAWQForCausalLM(nn.Module):
        if fuse_layers:
            self.fuse_layers(model)

        if use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
            # creates q4 handle and allocates scratch spaces wrt max_input_len and
            # max_batch_size, which are hardcoded for now but might be worth interfacing
            model = exllamav2_post_init(
                model,
                max_input_len=max_new_tokens,
                max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1))
            )

        return self(
            model,
            model_type,
            is_quantized=is_quantized,
            config=config,
            quant_config=quant_config,
            processor=None,
        )

    def _load_config(
        self,
        model_path,
        model_filename,
        safetensors=True,
        version="GEMM",
        trust_remote_code=True,
        max_new_tokens=4096,
        **config_kwargs,
    ):
        # [STEP 1] Download model if path is not a directory
        if not os.path.isdir(model_path):
            ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]

@@ -253,8 +347,8 @@ class BaseAWQForCausalLM(nn.Module):
            model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)

        if model_filename != "":
            model_weights_path = model_path + f"/{model_filename}"
        else:
            model_weights_path = model_path

@@ -263,22 +357,33 @@ class BaseAWQForCausalLM(nn.Module):
        quant_config = AwqConfig.from_pretrained(model_path)

        # Load model config and set max generation length
        if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
            # To add the generate support for Multi-modal models as well
            if hasattr(config, "text_config"):
                config.text_config.max_new_tokens = getattr(
                    config, self.max_new_tokens_key, 2048
                )
        else:
            max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
            config = AutoConfig.from_pretrained(
                model_path, trust_remote_code=trust_remote_code, **config_kwargs
            )
            config.max_new_tokens = max_new_tokens

        return model_weights_path, config, quant_config

    def _load_quantized_modules(
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."

        # Get blocks of model
        layers = self.get_model_layers(model)

@@ -290,23 +395,26 @@ class BaseAWQForCausalLM(nn.Module):
            named_linears = get_named_linears(layer)

            # Filter out the linear layers we don't want to exclude
            named_linears = exclude_layers_to_not_quantize(
                named_linears, quant_config.modules_to_not_convert
            )

            # Replace activation functions
            self._scale_activations(self, layer)

            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
                elif version == "GEMM":
                    q_linear_module = WQLinear_GEMM
                elif version == "GEMV":
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module, quant_config.w_bit, quant_config.q_group_size, True
                )
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)

@@ -318,13 +426,15 @@ class BaseAWQForCausalLM(nn.Module):
    def _scale_activations(self, layer):
        scale_dict = self.get_act_for_scaling(layer)

        if scale_dict["is_scalable"]:
            if not isinstance(scale_dict["scale_layer"], ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_like = torch.ones(
                    scale_dict["scale_shape"], dtype=param.dtype, device=param.device
                )

                # scale activation
                scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
                set_op_by_name(layer, scale_dict["scale_name"], scaled_act)
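Editor's note (not part of the diff): a condensed restatement of the kernel-selection order implemented in _load_quantized_modules above, which may help when reviewing the precedence rules. select_wqlinear is a hypothetical helper written only for this note.

from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2

def select_wqlinear(version: str, use_exllama: bool, use_exllama_v2: bool):
    # Mirrors the checks above: the exllama flags take precedence over the
    # quant_config version string, and GEMV checkpoints cannot use the new kernels.
    assert not (
        version == "GEMV" and (use_exllama or use_exllama_v2)
    ), "Exllama kernels only support GEMM version."
    if use_exllama:
        return WQLinear_Exllama
    if use_exllama_v2:
        return WQLinear_ExllamaV2
    return WQLinear_GEMM if version == "GEMM" else WQLinear_GEMV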
@@ -109,3 +109,5 @@ class FalconFuser:
            self.model.transformer.word_embeddings,
            self.model.transformer.ln_f,
        )

        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
\ No newline at end of file
@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlamaAWQForCausalLM(BaseAWQForCausalLM):

@@ -95,11 +94,6 @@ class LlamaFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon

@@ -114,7 +108,7 @@ class LlamaFuser:
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,

@@ -128,3 +122,4 @@ class LlamaFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file
@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration as OldLlavaForConditionalGeneration
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class LlavaAWQForCausalLM(BaseAWQForCausalLM):

@@ -95,11 +94,6 @@ class LlavaFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon

@@ -114,16 +108,17 @@ class LlavaFuser:
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
@@ -8,7 +8,6 @@ from transformers.models.mistral.modeling_mistral import (
    MistralDecoderLayer as OldMistralDecoderLayer,
    MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MistralAWQForCausalLM(BaseAWQForCausalLM):

@@ -95,11 +94,6 @@ class MistralFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon

@@ -114,7 +108,7 @@ class MistralFuser:
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,

@@ -127,3 +121,4 @@ class MistralFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
@@ -8,7 +8,6 @@ from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldMixtralDecoderLayer,
    MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MixtralAWQForCausalLM(BaseAWQForCausalLM):

@@ -98,14 +97,6 @@ class MixtralFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            # Adapt to mixture of experts
            for i in range(len(module.block_sparse_moe.experts)):
                mlp = QuantFusedMLP(
                    gate_proj=module.block_sparse_moe.experts[i].w1,
                    down_proj=module.block_sparse_moe.experts[i].w2,
                    up_proj=module.block_sparse_moe.experts[i].w3
                )
                module.block_sparse_moe.experts[i] = mlp
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon

@@ -134,4 +125,5 @@ class MixtralFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)
@@ -105,3 +105,5 @@ class MptFuser:
            self.model.transformer.wte,
            self.model.transformer.norm_f,
        )

        setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
\ No newline at end of file
@@ -4,7 +4,6 @@ from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class YiAWQForCausalLM(BaseAWQForCausalLM):

@@ -90,11 +89,6 @@ class YiFuser:
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.ln1.weight,
                module.ln1.variance_epsilon

@@ -109,7 +103,7 @@ class YiFuser:
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=module.mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,

@@ -123,3 +117,4 @@ class YiFuser:
            self.model.model.embed_tokens,
            self.model.model.norm,
        )

        setattr(self.model.model, "blocks", self.model.model.blocks)

import torch
import torch.nn as nn
from awq.utils.exllama_utils import unpack_reorder_pack
import exl_ext # with CUDA kernels (AutoAWQ_kernels)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_Exllama(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for Exllama kernels")
self.q4 = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
self.q4 = exl_ext.make_q4(
self.qweight,
self.qzeros,
self.scales,
none_tensor, # g_idx
self.qweight.device.index, # device index
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for Exllama kernels")
def forward(self, x):
assert self.q4 is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exl_ext.q4_matmul(x, self.q4, out)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def exllama_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Exllama):
submodule.post_init()
return model
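Editor's note (not part of the diff): the loader is expected to call exllama_post_init once the quantized weights sit on a CUDA device, since the q4 handle can only be built there. A rough sanity-check sketch, assuming an AWQ GEMM checkpoint is reachable and a CUDA GPU is available:

from awq import AutoAWQForCausalLM
from awq.modules.exllama import WQLinear_Exllama

# from_quantized(use_exllama=True) repacks the GEMM weights and runs
# exllama_post_init internally, so every Exllama layer ends up with a q4 handle.
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ", fuse_layers=False, use_exllama=True
)

for name, module in model.model.named_modules():
    if isinstance(module, WQLinear_Exllama):
        assert module.q4 is not None, f"{name} was not post-initialized"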
import torch
import torch.nn as nn
from typing import Dict
from awq.utils.exllama_utils import unpack_reorder_pack
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_ExllamaV2(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.q_handle = None
self.q_tensors = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self, scratch_space: "ScratchSpace"):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
temp_dq_size = self.temp_dq_size()
temp_dq = scratch_space.get_slice(temp_dq_size)
self.q_handle = exlv2_ext.make_q_matrix(
self.qweight,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
self.qzeros,
self.scales,
none_tensor,
temp_dq,
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for ExllamaV2 kernels")
def temp_dq_size(self):
"""
Returns the size of the temporary buffer required for the dq kernel.
"""
return self.in_features * self.out_features * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
"""
Returns the size of the temporary buffer required for the fwd kernel.
"""
return self.out_features * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
"""
Returns the size of the fixed scratch space required for the kernel.
"""
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
def forward(self, x):
assert self.q_handle is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exlv2_ext.gemm_half_q_half(x, self.q_handle, out, False)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
class ScratchSpace:
def __init__(self, scratch_bytes, dev):
self.scratch_bytes = scratch_bytes
self.scratch = torch.empty(
self.scratch_bytes // 2,
dtype=torch.float16,
device=dev,
)
def get_slice(self, size_bytes):
size_halfs = next_multiple(size_bytes, 128) // 2
scratch_slice = self.scratch.narrow(0, 0, size_halfs)
return scratch_slice
def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8):
# we search for the maximum number of bytes required for each device's scratch space
fixed_bytes: Dict[torch.device, int] = {}
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed(
max_input_len=max_input_len, max_batch_size=max_batch_size
)
fixed_bytes[device] = max(fixed_bytes.get(device, 0), scratch_fixed)
# we allocate a model-persistent scratch space for each device
model.scratch_spaces: Dict[torch.device, ScratchSpace] = {}
for device, scratch_bytes in fixed_bytes.items():
model.scratch_spaces[device] = ScratchSpace(scratch_bytes, device)
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
submodule.post_init(scratch_space=model.scratch_spaces[device])
return model
def next_multiple(x, multiple):
return ((x + multiple - 1) // multiple) * multiple
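Editor's note (not part of the diff): for intuition about the scratch-space sizing above, a back-of-the-envelope calculation for a single hypothetical 4096x4096 layer with the default max_input_len/max_batch_size; the numbers are illustrative only.

in_features, out_features = 4096, 4096   # hypothetical layer shape
max_input_len, max_batch_size = 2048, 8  # defaults used by exllamav2_post_init

temp_dq_size = in_features * out_features * 2 + 128                     # fp16 dequant buffer
temp_fwd_size = out_features * max_input_len * max_batch_size * 4 + 128  # fwd workspace
scratch_fixed = temp_dq_size + temp_fwd_size

# ScratchSpace then allocates scratch_bytes // 2 fp16 elements per device,
# sized by the largest layer found on that device.
print(f"{scratch_fixed / 1024**2:.1f} MiB")  # ~288 MiB for this shape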
@@ -69,7 +69,7 @@ class LlamaLikeModel(nn.Module):
        super().__init__()
        self.vocab_size = vocab_size
        self.embedding = embedding
        self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
        self.norm = norm
        self.last_forward_num_tokens = 0
import math
import torch
import torch.nn as nn
import awq_ext  # with CUDA kernels

def make_divisible(c, divisor):
    return (c + divisor - 1) // divisor

def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    if group_size >= 128:
        size_multiplier = 1

@@ -21,6 +22,7 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
    base_width = make_divisible(base_width, size_multiplier) * size_multiplier
    return base_width

class WQLinear_GEMM(nn.Module):
    def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
        super().__init__()

@@ -37,17 +39,54 @@ class WQLinear_GEMM(nn.Module):
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0

        self.register_buffer(
            "qweight",
            torch.zeros(
                (in_features, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (in_features // self.group_size, out_features // (32 // self.w_bit)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (in_features // self.group_size, out_features),
                dtype=torch.float16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias",
                torch.zeros(
                    (out_features),
                    dtype=torch.float16,
                    device=dev,
                ),
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only: # just prepare for loading sd
            return awq_linear

@@ -63,11 +102,20 @@ class WQLinear_GEMM(nn.Module):
        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[idx // group_size])
                    / awq_linear.scales[idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.t().contiguous()
        intweight = intweight.to(dtype=torch.int32)
        qweight = torch.zeros(
            (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=intweight.device,
        )

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.w_bit == 4:

@@ -80,7 +128,11 @@ class WQLinear_GEMM(nn.Module):
        awq_linear.qweight = qweight

        zeros = zeros.to(dtype=torch.int32)
        qzeros = torch.zeros(
            (zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=zeros.device,
        )

        for col in range(zeros.shape[1] // pack_num):
            if awq_linear.w_bit == 4:

@@ -96,13 +148,15 @@ class WQLinear_GEMM(nn.Module):
    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)

        input_dtype = x.dtype
        if input_dtype != torch.float16:
            x = x.half()

        out = awq_ext.gemm_forward_cuda(
            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
        )

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

@@ -111,8 +165,14 @@ class WQLinear_GEMM(nn.Module):
        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )

@@ -132,19 +192,52 @@ class WQLinear_GEMV(nn.Module):
        # quick sanity check (make sure aligment)
        assert self.in_features % self.group_size == 0
        assert out_features % (32 // self.w_bit) == 0
        pack_num = 32 // self.w_bit

        self.register_buffer(
            "qweight",
            torch.zeros(
                (out_features, in_features // pack_num), dtype=torch.int32, device=dev
            ),
        )
        self.register_buffer(
            "qzeros",
            torch.zeros(
                (out_features, calculate_zeros_width(in_features, self.group_size)),
                dtype=torch.int32,
                device=dev,
            ),
        )
        self.register_buffer(
            "scales",
            torch.zeros(
                (
                    out_features,
                    calculate_zeros_width(in_features, self.group_size) * pack_num,
                ),
                dtype=torch.float16,
                device=dev,
            ),
        )
        if bias:
            self.register_buffer(
                "bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
            )
        else:
            self.bias = None

    @classmethod
    def from_linear(
        cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
    ):
        awq_linear = cls(
            w_bit,
            group_size,
            linear.in_features,
            linear.out_features,
            linear.bias is not None,
            linear.weight.device,
        )
        if init_only: # just prepare for loading sd
            return awq_linear

@@ -154,21 +247,33 @@ class WQLinear_GEMV(nn.Module):
        pack_num = 32 // awq_linear.w_bit
        qscales = torch.zeros(
            (
                scales.shape[0],
                calculate_zeros_width(linear.in_features, group_size) * pack_num,
            ),
            dtype=torch.float16,
            device=scales.device,
        )
        qscales[:, : scales.shape[1]] = scales
        awq_linear.scales = qscales
        if linear.bias is not None:
            awq_linear.bias = linear.bias.clone().half()

        intweight = []
        for idx in range(awq_linear.in_features):
            intweight.append(
                torch.round(
                    (linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
                    / awq_linear.scales[:, idx // group_size]
                ).to(torch.int)[:, None]
            )
        intweight = torch.cat(intweight, dim=1)
        intweight = intweight.to(dtype=torch.int32)
        qweight = torch.zeros(
            (intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
            dtype=torch.int32,
            device=intweight.device,
        )

        for col in range(intweight.shape[1] // pack_num):
            if awq_linear.w_bit == 4:

@@ -202,7 +307,7 @@ class WQLinear_GEMV(nn.Module):
    @torch.no_grad()
    def forward(self, x):
        out_shape = x.shape[:-1] + (self.out_features,)

        inputs = x.reshape(-1, x.shape[-1])
        input_dtype = inputs.dtype

@@ -210,9 +315,18 @@ class WQLinear_GEMV(nn.Module):
            inputs = inputs.half()

        if inputs.shape[0] > 8:
            out = awq_ext.gemmv2_forward_cuda(
                inputs,
                self.qweight,
                self.scales,
                self.qzeros,
                self.group_size,
                self.split_k_iters,
            )
        else:
            out = awq_ext.gemv_forward_cuda(
                inputs, self.qweight, self.scales, self.qzeros, self.group_size
            )

        if input_dtype != torch.float16:
            out = out.to(dtype=input_dtype)

@@ -221,6 +335,12 @@ class WQLinear_GEMV(nn.Module):
        return out.reshape(out_shape)

    def extra_repr(self) -> str:
        return (
            "in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
                self.in_features,
                self.out_features,
                self.bias is not None,
                self.w_bit,
                self.group_size,
            )
        )
import torch
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
izeros.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=iweights.device)
# packing rowwise
iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1)
qweight = (
torch.bitwise_left_shift(iweights, shifts[None, :, None])
.sum(dim=1)
.to(torch.int32)
)
# packing columnwise
izeros = izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits)
qzeros = (
torch.bitwise_left_shift(izeros, shifts[None, None, :])
.sum(dim=-1)
.to(torch.int32)
)
return qweight, qzeros
def unpack_reorder_pack(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
# We can remove it if we remove the +1 in the exllama code
izeros = izeros - 1
# Pack the qweight and qzeros tensors
qweight, qzeros = pack_exllama(iweight, izeros, bits)
return qweight, qzeros
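Editor's note (not part of the diff): a small shape-level sketch of the repacking helper above. The values are random, so this only illustrates the layout change from AWQ's column-packed format to exllama's row-packed format; it assumes the awq package with this commit is importable.

import torch
from awq.utils.exllama_utils import unpack_reorder_pack

bits, group_size = 4, 128
in_features, out_features = 256, 256

# AWQ packs 32 // bits values per int32 along the output dimension.
qweight = torch.randint(
    0, 2**31 - 1, (in_features, out_features * bits // 32), dtype=torch.int32
)
qzeros = torch.randint(
    0, 2**31 - 1, (in_features // group_size, out_features * bits // 32), dtype=torch.int32
)

new_qweight, new_qzeros = unpack_reorder_pack(qweight, qzeros, bits)
print(new_qweight.shape)  # (in_features * bits // 32, out_features): repacked along rows
print(new_qzeros.shape)   # (in_features // group_size, out_features * bits // 32)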
import torch

from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

def prepare_correct_devices(next_layer, hidden_states, mask):

@@ -52,8 +53,12 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
    if isinstance(q_proj, WQLinear_GEMV):
        q_linear = WQLinear_GEMV
    elif isinstance(q_proj, WQLinear_GEMM):
        q_linear = WQLinear_GEMM
    elif isinstance(q_proj, WQLinear_Exllama):
        q_linear = WQLinear_Exllama
    else:
        q_linear = WQLinear_ExllamaV2

    qkv_layer = q_linear(
        q_proj.w_bit,

@@ -64,12 +69,20 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        next(iter(module.state_dict().values())).device
    )

    if isinstance(q_proj, WQLinear_GEMV):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
        qkv_layer.split_k_iters = q_proj.split_k_iters
    elif isinstance(q_proj, WQLinear_GEMM):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
    elif isinstance(q_proj, WQLinear_Exllama):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
    elif isinstance(q_proj, WQLinear_ExllamaV2):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, use_exllama_v2=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
\ No newline at end of file
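Editor's note (not part of the diff): the same script should also work with the first-generation Exllama kernels by swapping the flag. A minimal variant under the same assumptions (checkpoint and chat template as above):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"

# Same loading path as above, selecting the Exllama (v1) kernels instead of v2.
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, use_exllama=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

tokens = tokenizer("[INST] Say hello. [/INST]", return_tensors="pt").input_ids.cuda()
output = model.generate(tokens, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))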