Unverified commit 2fcbf268 authored by Ilyas Moutawwakil, committed by GitHub

Exllama kernels support (#313)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent a3db8099
......@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldAquilaDecoderLayer,
LlamaForCausalLM as OldAquilaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class AquilaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -95,11 +94,6 @@ class AquilaFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
......@@ -114,7 +108,7 @@ class AquilaFuser:
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
......@@ -127,3 +121,5 @@ class AquilaFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
......@@ -23,39 +23,76 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"llava": LlavaAWQForCausalLM,
}
def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs)
config = AutoConfig.from_pretrained(
model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs
)
if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
return model_type
class AutoAWQForCausalLM:
def __init__(self):
raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
raise EnvironmentError(
"You must instantiate AutoAWQForCausalLM with\n"
"AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained"
)
@classmethod
def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code, **model_init_kwargs)
def from_pretrained(
self,
model_path,
trust_remote_code=True,
safetensors=False,
device_map=None,
**model_init_kwargs,
) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(
model_path, trust_remote_code, **model_init_kwargs
)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
device_map=device_map, **model_init_kwargs
model_path,
model_type,
trust_remote_code=trust_remote_code,
safetensors=safetensors,
device_map=device_map,
**model_init_kwargs,
)
@classmethod
def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
trust_remote_code=True, fuse_layers=True,
batch_size=1, safetensors=True,
device_map="balanced", offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
def from_quantized(
self,
quant_path,
quant_filename="",
max_new_tokens=None,
trust_remote_code=True,
fuse_layers=True,
use_exllama=False,
use_exllama_v2=False,
batch_size=1,
safetensors=True,
device_map="balanced",
offload_folder=None,
**config_kwargs,
) -> BaseAWQForCausalLM:
os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
model_type = check_and_get_model_type(quant_path, trust_remote_code)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers, safetensors=safetensors,
device_map=device_map, offload_folder=offload_folder,
**config_kwargs
quant_path,
model_type,
quant_filename,
max_new_tokens,
trust_remote_code=trust_remote_code,
fuse_layers=fuse_layers,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
safetensors=safetensors,
device_map=device_map,
offload_folder=offload_folder,
**config_kwargs,
)
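The two flags added above (use_exllama, use_exllama_v2) are simply forwarded to BaseAWQForCausalLM.from_quantized. A minimal usage sketch, assuming the quantized Mistral checkpoint used in the example script at the end of this diff (the flag combinations shown are illustrative, not part of the change itself):

# Hedged usage sketch; the checkpoint path is taken from the example script below.
from awq import AutoAWQForCausalLM

# ExLlama kernels (GEMM-packed checkpoints only, per the assertion added in base.py)
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ", fuse_layers=True, use_exllama=True
)

# ExLlamaV2 kernels, which additionally pre-allocate per-device scratch space
model = AutoAWQForCausalLM.from_quantized(
    "TheBloke/Mistral-7B-Instruct-v0.1-AWQ", fuse_layers=True, use_exllama_v2=True
)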
......@@ -5,9 +5,7 @@ from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
......@@ -102,11 +100,6 @@ class BaichuanFuser:
# module.self_attn.v_proj
# )
qkv = module.self_attn.W_pack
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.epsilon
......@@ -121,7 +114,7 @@ class BaichuanFuser:
n_kv_heads=self.model.config.num_attention_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
......@@ -135,3 +128,5 @@ class BaichuanFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import os
import gc
import json
import time
import torch
import torch.nn as nn
......@@ -14,6 +14,8 @@ import transformers
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
get_named_linears,
set_op_by_name,
......@@ -34,6 +36,8 @@ from accelerate.big_modeling import (
from awq.models._config import AwqConfig
from awq.modules.act import ScaledActivation
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.quantize.quantizer import AwqQuantizer
from awq.utils.module import get_named_linears, set_op_by_name
......@@ -59,12 +63,15 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
"llava": "AutoModelForVision2Seq",
}
class BaseAWQForCausalLM(nn.Module):
def __init__(self, model, model_type, is_quantized, config, quant_config, processor):
def __init__(
self, model, model_type, is_quantized, config, quant_config, processor
):
super().__init__()
self.model:PreTrainedModel = model
self.model_type:str = model_type
self.is_quantized:bool = is_quantized
self.model: PreTrainedModel = model
self.model_type: str = model_type
self.is_quantized: bool = is_quantized
self.search_result = None
self.config: PretrainedConfig = config
self.quant_config: AwqConfig = quant_config
......@@ -81,16 +88,32 @@ class BaseAWQForCausalLM(nn.Module):
return self.model.generate(*args, **kwargs)
@torch.no_grad()
def quantize(self, tokenizer=None, quant_config={},
calib_data: Union[str, List[str]]="pileval",
split="train", text_column="text", duo_scaling=True,
modules_to_not_convert=None, export_compatible=False):
def quantize(
self,
tokenizer=None,
quant_config={},
calib_data: Union[str, List[str]] = "pileval",
split="train",
text_column="text",
duo_scaling=True,
modules_to_not_convert=None,
export_compatible=False,
):
self.quant_config: AwqConfig = AwqConfig.from_dict(quant_config)
self.quantizer = AwqQuantizer(
self, self.model, tokenizer, self.quant_config.w_bit, self.quant_config.q_group_size,
self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert,
export_compatible=export_compatible
self,
self.model,
tokenizer,
self.quant_config.w_bit,
self.quant_config.q_group_size,
self.quant_config.version,
calib_data,
split,
text_column,
duo_scaling,
modules_to_not_convert=modules_to_not_convert,
export_compatible=export_compatible,
)
self.quantizer.quantize()
......@@ -118,12 +141,15 @@ class BaseAWQForCausalLM(nn.Module):
pass
def save_quantized(self, save_dir, safetensors=True, shard_size="10GB"):
save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
save_dir = save_dir[:-1] if save_dir[-1] == "/" else save_dir
# Save model
class EmptyModule(nn.Module):
def __init__(self): super(EmptyModule, self).__init__()
def forward(self, x): return x
def __init__(self):
super(EmptyModule, self).__init__()
def forward(self, x):
return x
# Save model and config files with empty state dict
self.model.config.quantization_config = self.quant_config.to_transformers_dict()
......@@ -135,42 +161,51 @@ class BaseAWQForCausalLM(nn.Module):
self.processor.save_pretrained(save_dir)
# Remove empty state dict
default_paths = [f'{save_dir}/model.safetensors', f'{save_dir}/pytorch_model.bin']
default_paths = [
f"{save_dir}/model.safetensors",
f"{save_dir}/pytorch_model.bin",
]
for path in default_paths:
if os.path.exists(path):
os.remove(path)
# model_name has no extension, add it when saving state_dict
model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'
model_name = "model.safetensors" if safetensors else "pytorch_model.bin"
# shard checkpoint into chunks (10GB default)
shards, index = shard_checkpoint(
self.model.state_dict(),
max_shard_size=shard_size,
weights_name=model_name
self.model.state_dict(), max_shard_size=shard_size, weights_name=model_name
)
for shard_file, shard in shards.items():
if safetensors:
# safetensors requires tensors that own contiguous memory, so we clone them into contiguous copies
shard = {k: v.clone().contiguous() for k, v in shard.items()}
save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
save_file(
shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"}
)
else:
torch.save(shard, os.path.join(save_dir, shard_file))
# save shard index
if index is not None:
with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
with open(f"{save_dir}/{model_name}.index.json", "w+") as file:
file.write(json.dumps(index, indent=4))
@classmethod
def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
trust_remote_code=True, safetensors=False, device_map=None,
**model_init_kwargs):
def from_pretrained(
self,
model_path,
model_type,
torch_dtype: torch.dtype = torch.float16,
trust_remote_code=True,
safetensors=False,
device_map=None,
**model_init_kwargs,
):
# Get weights path and quant config
model_weights_path, config, quant_config = self._load_config(
self, model_path, '', safetensors, trust_remote_code=trust_remote_code
self, model_path, "", safetensors, trust_remote_code=trust_remote_code
)
target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
......@@ -188,26 +223,49 @@ class BaseAWQForCausalLM(nn.Module):
torch_dtype=torch_dtype,
use_safetensors=safetensors,
device_map=device_map,
**model_init_kwargs
**model_init_kwargs,
)
model.eval()
return self(model, model_type, is_quantized=False, config=config,
quant_config=quant_config, processor=processor)
return self(
model,
model_type,
is_quantized=False,
config=config,
quant_config=quant_config,
processor=processor,
)
@classmethod
def from_quantized(self, model_path, model_type, model_filename='',
max_new_tokens=None, torch_dtype=torch.float16,
trust_remote_code=True, safetensors=True, is_quantized=True,
fuse_layers=False, version='GEMM',
device_map="balanced", offload_folder=None,
**config_kwargs):
def from_quantized(
self,
model_path,
model_type,
model_filename="",
max_new_tokens=None,
torch_dtype=torch.float16,
trust_remote_code=True,
safetensors=True,
is_quantized=True,
fuse_layers=False,
use_exllama=False,
use_exllama_v2=False,
version="GEMM",
device_map="balanced",
offload_folder=None,
**config_kwargs,
):
# [STEP 1-2] Load weights path and configs
model_weights_path, config, quant_config = self._load_config(
self, model_path, model_filename, safetensors, version,
trust_remote_code, max_new_tokens=max_new_tokens,
**config_kwargs
self,
model_path,
model_filename,
safetensors,
version,
trust_remote_code,
max_new_tokens=max_new_tokens,
**config_kwargs,
)
target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT[config.model_type]
......@@ -215,10 +273,21 @@ class BaseAWQForCausalLM(nn.Module):
# [STEP 3] Load model
with init_empty_weights():
model = target_cls.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
model = target_cls.from_config(
config=config,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
# Prepare WQLinear layers, replace nn.Linear
self._load_quantized_modules(self, model, quant_config, quant_config.version)
self._load_quantized_modules(
self,
model,
quant_config,
quant_config.version,
use_exllama=use_exllama,
use_exllama_v2=use_exllama_v2,
)
model.tie_weights()
......@@ -237,12 +306,37 @@ class BaseAWQForCausalLM(nn.Module):
if fuse_layers:
self.fuse_layers(model)
return self(model, model_type, is_quantized=is_quantized, config=config,
quant_config=quant_config, processor=None)
if use_exllama:
# creates q4 handle
model = exllama_post_init(model)
elif use_exllama_v2:
# creates the q4 handle and allocates scratch space sized from max_input_len
# and max_batch_size, which are currently hardcoded but might be worth
# exposing as parameters
model = exllamav2_post_init(
model,
# fall back to the ExllamaV2 post-init default when max_new_tokens is unset
max_input_len=max_new_tokens or 2048,
max_batch_size=int(os.getenv("AWQ_BATCH_SIZE", 1)),
)
return self(
model,
model_type,
is_quantized=is_quantized,
config=config,
quant_config=quant_config,
processor=None,
)
def _load_config(self, model_path, model_filename, safetensors=True,
version="GEMM", trust_remote_code=True, max_new_tokens=4096,
**config_kwargs):
def _load_config(
self,
model_path,
model_filename,
safetensors=True,
version="GEMM",
trust_remote_code=True,
max_new_tokens=4096,
**config_kwargs,
):
# [STEP 1] Download model if path is not a directory
if not os.path.isdir(model_path):
ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
......@@ -253,8 +347,8 @@ class BaseAWQForCausalLM(nn.Module):
model_path = snapshot_download(model_path, ignore_patterns=ignore_patterns)
if model_filename != '':
model_weights_path = model_path + f'/{model_filename}'
if model_filename != "":
model_weights_path = model_path + f"/{model_filename}"
else:
model_weights_path = model_path
......@@ -263,22 +357,33 @@ class BaseAWQForCausalLM(nn.Module):
quant_config = AwqConfig.from_pretrained(model_path)
# Load model config and set max generation length
if max_new_tokens is None and hasattr(self, 'max_new_tokens_key'):
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
if max_new_tokens is None and hasattr(self, "max_new_tokens_key"):
config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code, **config_kwargs
)
config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
# To add the generate support for Multi-modal models as well
if hasattr(config, "text_config"):
config.text_config.max_new_tokens = getattr(config, self.max_new_tokens_key, 2048)
config.text_config.max_new_tokens = getattr(
config, self.max_new_tokens_key, 2048
)
else:
max_new_tokens = 2048 if max_new_tokens is None else max_new_tokens
config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code, **config_kwargs)
config = AutoConfig.from_pretrained(
model_path, trust_remote_code=trust_remote_code, **config_kwargs
)
config.max_new_tokens = max_new_tokens
return model_weights_path, config, quant_config
def _load_quantized_modules(self, model, quant_config, version):
def _load_quantized_modules(
self, model, quant_config, version, use_exllama, use_exllama_v2
):
# Real quantization of weights
assert quant_config.zero_point, "We only support zero_point quantization now."
assert not (
version == "GEMV" and (use_exllama or use_exllama_v2)
), "Exllama kernels only support GEMM version."
# Get blocks of model
layers = self.get_model_layers(model)
......@@ -290,23 +395,26 @@ class BaseAWQForCausalLM(nn.Module):
named_linears = get_named_linears(layer)
# Filter out the linear layers we don't want to exclude
named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)
named_linears = exclude_layers_to_not_quantize(
named_linears, quant_config.modules_to_not_convert
)
# Replace activation functions
self._scale_activations(self, layer)
# Replace nn.Linear with WQLinear
for name, module in named_linears.items():
if version == 'GEMM':
if use_exllama:
q_linear_module = WQLinear_Exllama
elif use_exllama_v2:
q_linear_module = WQLinear_ExllamaV2
elif version == "GEMM":
q_linear_module = WQLinear_GEMM
elif version == 'GEMV':
elif version == "GEMV":
q_linear_module = WQLinear_GEMV
q_linear = q_linear_module.from_linear(
module,
quant_config.w_bit,
quant_config.q_group_size,
True
module, quant_config.w_bit, quant_config.q_group_size, True
)
q_linear.to(next(layer.parameters()).device)
set_op_by_name(layer, name, q_linear)
......@@ -318,13 +426,15 @@ class BaseAWQForCausalLM(nn.Module):
def _scale_activations(self, layer):
scale_dict = self.get_act_for_scaling(layer)
if scale_dict['is_scalable']:
if not isinstance(scale_dict['scale_layer'], ScaledActivation):
if scale_dict["is_scalable"]:
if not isinstance(scale_dict["scale_layer"], ScaledActivation):
param = next(layer.parameters())
# get activation scale
scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
scale_like = torch.ones(
scale_dict["scale_shape"], dtype=param.dtype, device=param.device
)
# scale activation
scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
scaled_act = ScaledActivation(scale_dict["scale_layer"], scale_like)
set_op_by_name(layer, scale_dict["scale_name"], scaled_act)
......@@ -109,3 +109,5 @@ class FalconFuser:
self.model.transformer.word_embeddings,
self.model.transformer.ln_f,
)
setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
\ No newline at end of file
......@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlamaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -95,11 +94,6 @@ class LlamaFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
......@@ -114,7 +108,7 @@ class LlamaFuser:
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
......@@ -128,3 +122,4 @@ class LlamaFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
\ No newline at end of file
......@@ -8,7 +8,6 @@ from transformers.models.llama.modeling_llama import (
LlamaDecoderLayer as OldLlamaDecoderLayer,
)
from transformers.models.llava.modeling_llava import LlavaForConditionalGeneration as OldLlavaForConditionalGeneration
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class LlavaAWQForCausalLM(BaseAWQForCausalLM):
......@@ -95,11 +94,6 @@ class LlavaFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
......@@ -114,16 +108,17 @@ class LlavaFuser:
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
max_seq_len=self.model.config.max_new_tokens
))
self.model = LlamaLikeModel(
self.model.model = LlamaLikeModel(
self.model.config.vocab_size,
blocks,
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
......@@ -8,7 +8,6 @@ from transformers.models.mistral.modeling_mistral import (
MistralDecoderLayer as OldMistralDecoderLayer,
MistralForCausalLM as OldMistralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MistralAWQForCausalLM(BaseAWQForCausalLM):
......@@ -95,11 +94,6 @@ class MistralFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
......@@ -114,7 +108,7 @@ class MistralFuser:
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
......@@ -127,3 +121,4 @@ class MistralFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
......@@ -8,7 +8,6 @@ from transformers.models.mixtral.modeling_mixtral import (
MixtralDecoderLayer as OldMixtralDecoderLayer,
MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class MixtralAWQForCausalLM(BaseAWQForCausalLM):
......@@ -98,14 +97,6 @@ class MixtralFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
# Adapt to mixture of experts
for i in range(len(module.block_sparse_moe.experts)):
mlp = QuantFusedMLP(
gate_proj=module.block_sparse_moe.experts[i].w1,
down_proj=module.block_sparse_moe.experts[i].w2,
up_proj=module.block_sparse_moe.experts[i].w3
)
module.block_sparse_moe.experts[i] = mlp
norm_1 = FasterTransformerRMSNorm(
module.input_layernorm.weight,
module.input_layernorm.variance_epsilon
......@@ -134,4 +125,5 @@ class MixtralFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
......@@ -105,3 +105,5 @@ class MptFuser:
self.model.transformer.wte,
self.model.transformer.norm_f,
)
setattr(self.model.transformer, "blocks", self.model.transformer.blocks)
\ No newline at end of file
......@@ -4,7 +4,6 @@ from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm
class YiAWQForCausalLM(BaseAWQForCausalLM):
......@@ -90,11 +89,6 @@ class YiFuser:
module.self_attn.k_proj,
module.self_attn.v_proj
)
mlp = QuantFusedMLP(
module.mlp.gate_proj,
module.mlp.down_proj,
module.mlp.up_proj
)
norm_1 = FasterTransformerRMSNorm(
module.ln1.weight,
module.ln1.variance_epsilon
......@@ -109,7 +103,7 @@ class YiFuser:
n_kv_heads=self.model.config.num_key_value_heads,
qkv_layer=qkv,
o_proj=module.self_attn.o_proj,
mlp=mlp,
mlp=module.mlp,
norm_1=norm_1,
norm_2=norm_2,
dev=device,
......@@ -123,3 +117,4 @@ class YiFuser:
self.model.model.embed_tokens,
self.model.model.norm,
)
setattr(self.model.model, "blocks", self.model.model.blocks)
import torch
import torch.nn as nn
from awq.utils.exllama_utils import unpack_reorder_pack
import exl_ext # with CUDA kernels (AutoAWQ_kernels)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_Exllama(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for Exllama kernels")
self.q4 = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
self.q4 = exl_ext.make_q4(
self.qweight,
self.qzeros,
self.scales,
none_tensor, # g_idx
self.qweight.device.index, # device index
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for Exllama kernels")
def forward(self, x):
assert self.q4 is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exl_ext.q4_matmul(x, self.q4, out)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def exllama_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Exllama):
submodule.post_init()
return model
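For anyone using WQLinear_Exllama outside of from_quantized, a hedged sketch of the intended life cycle, assuming a 4-bit, group_size-128 GEMM-style checkpoint is available to fill the buffers (the layer shape below is illustrative):

# Hedged sketch, not part of the diff: build a single Exllama layer and its q4 handle.
import torch
import torch.nn as nn
from awq.modules.exllama import WQLinear_Exllama, exllama_post_init

fp16_linear = nn.Linear(4096, 4096, bias=False, dtype=torch.float16, device="cuda")
q_linear = WQLinear_Exllama.from_linear(fp16_linear, w_bit=4, group_size=128, init_only=True)
# ... load qweight / qzeros / scales from a quantized state dict here ...

module = nn.Sequential(q_linear)
module = exllama_post_init(module)  # repacks the weights and calls exl_ext.make_q4

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = module(x)  # forward() asserts that post_init() has been called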
import torch
import torch.nn as nn
from typing import Dict
from awq.utils.exllama_utils import unpack_reorder_pack
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_ExllamaV2(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.q_handle = None
self.q_tensors = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self, scratch_space: "ScratchSpace"):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
temp_dq_size = self.temp_dq_size()
temp_dq = scratch_space.get_slice(temp_dq_size)
self.q_handle = exlv2_ext.make_q_matrix(
self.qweight,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
self.qzeros,
self.scales,
none_tensor,
temp_dq,
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for ExllamaV2 kernels")
def temp_dq_size(self):
"""
Returns the size of the temporary buffer required for the dq kernel.
"""
return self.in_features * self.out_features * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
"""
Returns the size of the temporary buffer required for the fwd kernel.
"""
return self.out_features * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
"""
Returns the size of the fixed scratch space required for the kernel.
"""
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
def forward(self, x):
assert self.q_handle is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exlv2_ext.gemm_half_q_half(x, self.q_handle, out, False)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
class ScratchSpace:
def __init__(self, scratch_bytes, dev):
self.scratch_bytes = scratch_bytes
self.scratch = torch.empty(
self.scratch_bytes // 2,
dtype=torch.float16,
device=dev,
)
def get_slice(self, size_bytes):
size_halfs = next_multiple(size_bytes, 128) // 2
scratch_slice = self.scratch.narrow(0, 0, size_halfs)
return scratch_slice
def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8):
# we search for the maximum number of bytes required for each device's scratch space
fixed_bytes: Dict[torch.device, int] = {}
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed(
max_input_len=max_input_len, max_batch_size=max_batch_size
)
fixed_bytes[device] = max(fixed_bytes.get(device, 0), scratch_fixed)
# we allocate a model-persistent scratch space for each device
model.scratch_spaces: Dict[torch.device, ScratchSpace] = {}
for device, scratch_bytes in fixed_bytes.items():
model.scratch_spaces[device] = ScratchSpace(scratch_bytes, device)
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
submodule.post_init(scratch_space=model.scratch_spaces[device])
return model
def next_multiple(x, multiple):
return ((x + multiple - 1) // multiple) * multiple
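As a rough sense of what exllamav2_post_init reserves, a hedged back-of-the-envelope calculation for a single 4096x4096 projection with the default max_input_len=2048 and max_batch_size=8 (illustrative numbers, not measurements):

# Hedged arithmetic sketch, not part of the diff.
in_features, out_features = 4096, 4096
temp_dq = in_features * out_features * 2 + 128   # 33_554_560 B, ~32 MiB fp16 dequant buffer
temp_fwd = out_features * 2048 * 8 * 4 + 128     # 268_435_584 B, ~256 MiB forward buffer
print((temp_dq + temp_fwd) / 2**20)              # ~288 MiB of fixed scratch

Since fixed_bytes keeps only the per-device maximum, the largest layer on each device determines the size of that device's ScratchSpace.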
......@@ -69,7 +69,7 @@ class LlamaLikeModel(nn.Module):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[LlamaLikeBlock] = blocks
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
......
import math
import torch
import torch.nn as nn
import awq_ext # with CUDA kernels
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
......@@ -21,6 +22,7 @@ def calculate_zeros_width(in_features, group_size=128, pack_num=8):
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
class WQLinear_GEMM(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
......@@ -37,17 +39,54 @@ class WQLinear_GEMM(nn.Module):
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev))
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
......@@ -63,11 +102,20 @@ class WQLinear_GEMM(nn.Module):
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None])
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
......@@ -80,7 +128,11 @@ class WQLinear_GEMM(nn.Module):
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device)
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
......@@ -96,13 +148,15 @@ class WQLinear_GEMM(nn.Module):
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
out = awq_ext.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
......@@ -111,8 +165,14 @@ class WQLinear_GEMM(nn.Module):
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
......@@ -132,19 +192,52 @@ class WQLinear_GEMV(nn.Module):
# quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = (32 // self.w_bit)
pack_num = 32 // self.w_bit
self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev))
self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev))
self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev))
self.register_buffer(
"qweight",
torch.zeros(
(out_features, in_features // pack_num), dtype=torch.int32, device=dev
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(out_features, calculate_zeros_width(in_features, self.group_size)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(
out_features,
calculate_zeros_width(in_features, self.group_size) * pack_num,
),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev))
self.register_buffer(
"bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
)
else:
self.bias = None
@classmethod
def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None):
awq_linear = cls(w_bit, group_size, linear.in_features, linear.out_features, linear.bias is not None, linear.weight.device)
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
......@@ -154,21 +247,33 @@ class WQLinear_GEMV(nn.Module):
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num),
(
scales.shape[0],
calculate_zeros_width(linear.in_features, group_size) * pack_num,
),
dtype=torch.float16,
device=scales.device
device=scales.device,
)
qscales[:, :scales.shape[1]] = scales
qscales[:, : scales.shape[1]] = scales
awq_linear.scales = qscales
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None])
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ awq_linear.scales[:, idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
......@@ -202,7 +307,7 @@ class WQLinear_GEMV(nn.Module):
@torch.no_grad()
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features, )
out_shape = x.shape[:-1] + (self.out_features,)
inputs = x.reshape(-1, x.shape[-1])
input_dtype = inputs.dtype
......@@ -210,9 +315,18 @@ class WQLinear_GEMV(nn.Module):
inputs = inputs.half()
if inputs.shape[0] > 8:
out = awq_ext.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
out = awq_ext.gemmv2_forward_cuda(
inputs,
self.qweight,
self.scales,
self.qzeros,
self.group_size,
self.split_k_iters,
)
else:
out = awq_ext.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
out = awq_ext.gemv_forward_cuda(
inputs, self.qweight, self.scales, self.qzeros, self.group_size
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
......@@ -221,6 +335,12 @@ class WQLinear_GEMV(nn.Module):
return out.reshape(out_shape)
def extra_repr(self) -> str:
return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
import torch
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
return iweights, izeros
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
izeros.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=iweights.device)
# packing rowwise
iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1)
qweight = (
torch.bitwise_left_shift(iweights, shifts[None, :, None])
.sum(dim=1)
.to(torch.int32)
)
# packing columnwise
izeros = izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits)
qzeros = (
torch.bitwise_left_shift(izeros, shifts[None, None, :])
.sum(dim=-1)
.to(torch.int32)
)
return qweight, qzeros
def unpack_reorder_pack(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
# We can remove it if we remove the +1 in the exllama code
izeros = izeros - 1
# Pack the qweight and qzeros tensors
qweight, qzeros = pack_exllama(iweight, izeros, bits)
return qweight, qzeros
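A tiny hedged round trip through the helpers above, with toy shapes (4-bit, in_features=8, out_features=16, group_size=8) chosen only to make the change of packing axis visible:

# Hedged sketch, not part of the diff: AWQ packs along out_features, exllama along in_features.
import torch
from awq.utils.exllama_utils import unpack_reorder_pack

qweight = torch.randint(0, 2**31 - 1, (8, 16 // 8), dtype=torch.int32)      # (in_features, out_features / 8)
qzeros = torch.randint(0, 2**31 - 1, (8 // 8, 16 // 8), dtype=torch.int32)  # (in_features / group_size, out_features / 8)

new_qweight, new_qzeros = unpack_reorder_pack(qweight, qzeros, 4)
print(new_qweight.shape)  # torch.Size([1, 16]) -> packed along in_features
print(new_qzeros.shape)   # torch.Size([1, 2])  -> zero points, with the exllama +1 already subtracted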
import torch
from typing import List
from awq.modules.exllama import WQLinear_Exllama
from awq.modules.exllamav2 import WQLinear_ExllamaV2
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
def prepare_correct_devices(next_layer, hidden_states, mask):
......@@ -52,8 +53,12 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
if isinstance(q_proj, WQLinear_GEMV):
q_linear = WQLinear_GEMV
else:
elif isinstance(q_proj, WQLinear_GEMM):
q_linear = WQLinear_GEMM
elif isinstance(q_proj, WQLinear_Exllama):
q_linear = WQLinear_Exllama
else:
q_linear = WQLinear_ExllamaV2
qkv_layer = q_linear(
q_proj.w_bit,
......@@ -64,12 +69,20 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
next(iter(module.state_dict().values())).device
)
if isinstance(qkv_layer, WQLinear_GEMV):
if isinstance(q_proj, WQLinear_GEMV):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=0)
qkv_layer.split_k_iters = q_proj.split_k_iters
else:
elif isinstance(q_proj, WQLinear_GEMM):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
elif isinstance(q_proj, WQLinear_Exllama):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
elif isinstance(q_proj, WQLinear_ExllamaV2):
qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
......
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "TheBloke/Mistral-7B-Instruct-v0.1-AWQ"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, use_exllama_v2=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = "[INST] {prompt} [/INST]"
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
\ No newline at end of file