Unverified Commit 34085edc authored by Ilyas Moutawwakil, committed by GitHub

Marlin symmetric quantization and inference (#320)


Co-authored-by: Casper <casperbh.96@gmail.com>
parent f018d2b7
@@ -13,6 +13,7 @@ from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin, marlin_post_init
from awq.modules.linear.exllama import WQLinear_Exllama, exllama_post_init
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from awq.utils.module import (
@@ -103,6 +104,7 @@ class BaseAWQForCausalLM(nn.Module):
            tokenizer,
            self.quant_config.w_bit,
            self.quant_config.q_group_size,
            self.quant_config.zero_point,
            self.quant_config.version,
            calib_data,
            split,
@@ -149,6 +151,7 @@ class BaseAWQForCausalLM(nn.Module):
        # Save model and config files with empty state dict
        self.model.config.quantization_config = self.quant_config.to_transformers_dict()
        self.model.generation_config.do_sample = True
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
        self.quant_config.save_pretrained(save_dir)
@@ -302,7 +305,10 @@ class BaseAWQForCausalLM(nn.Module):
        if fuse_layers:
            self.fuse_layers(model)

        if quant_config.version == "Marlin":
            model = marlin_post_init(model)
        elif use_exllama:
            # creates q4 handle
            model = exllama_post_init(model)
        elif use_exllama_v2:
@@ -375,7 +381,6 @@ class BaseAWQForCausalLM(nn.Module):
        self, model, quant_config, version, use_exllama, use_exllama_v2
    ):
        # Real quantization of weights
        assert quant_config.zero_point, "We only support zero_point quantization now."
        assert not (
            version == "GEMV" and (use_exllama or use_exllama_v2)
        ), "Exllama kernels only support GEMM version."
@@ -399,7 +404,9 @@ class BaseAWQForCausalLM(nn.Module):
            # Replace nn.Linear with WQLinear
            for name, module in named_linears.items():
                if version == "Marlin":
                    q_linear_module = WQLinear_Marlin
                elif use_exllama:
                    q_linear_module = WQLinear_Exllama
                elif use_exllama_v2:
                    q_linear_module = WQLinear_ExllamaV2
...
@@ -2,3 +2,4 @@ from .exllama import WQLinear_Exllama
from .exllamav2 import WQLinear_ExllamaV2
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin
\ No newline at end of file
@@ -109,6 +109,11 @@ class WQLinear_Exllama(nn.Module):
            "Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
        )

        assert EXL_INSTALLED, (
            "ExllamaV2 kernels are not installed. "
            "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
        )

        input_dtype = x.dtype
        out_shape = x.shape[:-1] + (self.out_features,)
...
@@ -10,7 +10,6 @@ try:
except:
    EXLV2_INSTALLED = False

# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
@@ -23,7 +22,6 @@ class WQLinear_ExllamaV2(nn.Module):
            raise NotImplementedError("Only 4-bit are supported for now.")

        self.q_handle = None
        self.q_tensors = None
        self.w_bit = w_bit
        self.in_features = in_features
@@ -134,8 +132,8 @@ class WQLinear_ExllamaV2(nn.Module):
            "Use exllamav2_post_init() on the whole model."
        )
        assert EXLV2_INSTALLED, (
            "ExllamaV2 kernels are not installed. "
            "Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
        )

        input_dtype = x.dtype
...
import torch
import torch.nn as nn
import numpy as np
try:
import marlin_cuda # with CUDA kernels (AutoAWQ_kernels)
MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False


def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single


_perm, _scale_perm, _scale_perm_single = _get_perms()


class WQLinear_Marlin(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs
        # quick sanity check (make sure alignment)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
######################################################
        ## These shapes are specific to Marlin models ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features // 16, out_features * 16 // 8),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
######################################################
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls,
linear,
w_bit,
group_size,
init_only=False,
scales=None,
zeros=None,
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
assert zeros is None and scales is not None
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((-1, awq_linear.group_size, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape((awq_linear.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((awq_linear.group_size, -1, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape(
(awq_linear.in_features, awq_linear.out_features)
).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, awq_linear.out_features)).contiguous()
w = w.reshape(
(
awq_linear.in_features // tile,
tile,
awq_linear.out_features // tile,
tile,
)
)
w = w.permute((0, 2, 1, 3))
w = w.reshape((awq_linear.in_features // tile, awq_linear.out_features * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
awq_linear.qweight[:] = q.to(awq_linear.qweight.device)
awq_linear.scales[:] = s.to(awq_linear.qweight.device)
if awq_linear.bias is not None:
awq_linear.bias[:] = linear.bias.data.to(awq_linear.bias.device)
return awq_linear
def post_init(self):
self.register_buffer(
"workspace",
torch.zeros(
self.out_features // 128 * self.max_par,
dtype=torch.int32,
device=self.qweight.device,
),
persistent=False,
)
@torch.no_grad()
def forward(self, x):
assert hasattr(self, "workspace"), (
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
marlin_cuda.mul(
x,
self.qweight,
out,
self.scales,
self.workspace,
-1, # thread_k
-1, # thread_n
-1, # sms
self.max_par,
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)


def marlin_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Marlin):
submodule.post_init()
return model
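For context, here is a minimal, self-contained sketch of how WQLinear_Marlin could be exercised on its own, assuming the Marlin extension from AutoAWQ_kernels is installed and an Ampere-or-newer GPU is available; the layer sizes and the per-group scale rule below are illustrative assumptions, not part of this commit.

import torch
import torch.nn as nn
from awq.modules.linear.marlin import WQLinear_Marlin

# Illustrative shapes; fp16 weights on a CUDA device are required.
in_features, out_features, group_size = 4096, 4096, 128
linear = nn.Linear(in_features, out_features, bias=False).half().cuda()

# Symmetric per-group scales, shape (out_features, in_features // group_size),
# chosen so that round(w / s) stays inside the signed 4-bit range used above.
w_groups = linear.weight.data.view(out_features, -1, group_size)
scales = w_groups.abs().amax(dim=-1).clamp(min=1e-5) / 7

q_linear = WQLinear_Marlin.from_linear(
    linear, w_bit=4, group_size=group_size, scales=scales, zeros=None
)
q_linear.post_init()  # allocates the workspace the Marlin kernel needs

x = torch.randn(1, 8, in_features, dtype=torch.float16, device="cuda")
y = q_linear(x)  # shape (1, 8, out_features), fp16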
This diff is collapsed.
import torch
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
from awq.modules.linear.marlin import WQLinear_Marlin
from awq.modules.linear.exllama import WQLinear_Exllama
from awq.modules.linear.exllamav2 import WQLinear_ExllamaV2
@@ -58,8 +60,10 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        q_linear = WQLinear_GEMM
    elif isinstance(q_proj, WQLinear_Exllama):
        q_linear = WQLinear_Exllama
    elif isinstance(q_proj, WQLinear_ExllamaV2):
        q_linear = WQLinear_ExllamaV2
    elif isinstance(q_proj, WQLinear_Marlin):
        q_linear = WQLinear_Marlin

    qkv_layer = q_linear(
        q_proj.w_bit,
@@ -87,7 +91,11 @@ def fuse_qkv(module, q_proj, k_proj, v_proj):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
    elif isinstance(q_proj, WQLinear_Marlin):
        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
        # workspace is created in post_init

    qkv_layer.bias = bias

    return qkv_layer
...
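As a rough illustration of the Marlin branch above (the helper name and the bias-free assumption are mine, not part of this diff): because the three projections share in_features, their packed weights and scales can simply be concatenated along the output dimension, and the fused module only becomes usable once post_init allocates its workspace.

import torch
from awq.modules.linear.marlin import WQLinear_Marlin

def fuse_qkv_marlin(q_proj, k_proj, v_proj):
    # Hypothetical helper mirroring the Marlin branch of fuse_qkv above.
    fused = WQLinear_Marlin(
        q_proj.w_bit,
        q_proj.group_size,
        q_proj.in_features,
        q_proj.out_features + k_proj.out_features + v_proj.out_features,
        bias=False,  # bias handling omitted for brevity
        dev=q_proj.qweight.device,
    )
    # Packed weights and scales are stacked column-wise (output axis).
    fused.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
    fused.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
    fused.post_init()  # workspace is sized from the fused out_features
    return fused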
import torch
from typing import List
Q_BITS = 4
STORAGE_BITS = 32
PACK_NUM = STORAGE_BITS // Q_BITS
ORDINAL_PACK_ORDER = [0, 1, 2, 3, 4, 5, 6, 7]
AWQ_PACK_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
REVERSE_AWQ_PACK_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]


def pack(imatrix: torch.Tensor, direction: str = "column"):
"""
Packs a 4-bit integer matrix into a packed 32-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of packing, either "column" or "row"
Returns:
qmatrix (torch.Tensor): packed matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=imatrix.device)
imatrix = imatrix.to(torch.int8)
    imatrix = torch.bitwise_and(imatrix, 0x0F)  # mask to 4 bits to correct potential overflow
if direction == "column":
imatrix = imatrix.view(-1, imatrix.shape[1] // PACK_NUM, PACK_NUM)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, None, :]).sum(dim=-1)
elif direction == "row":
imatrix = imatrix.view(imatrix.shape[0] // PACK_NUM, PACK_NUM, -1)
qmatrix = torch.bitwise_left_shift(imatrix, shifts[None, :, None]).sum(dim=1)
qmatrix = qmatrix.to(torch.int32)
return qmatrix


def unpack(qmatrix: torch.Tensor, direction: str = "column"):
"""
Unpacks a 32-bit packed integer matrix into a 4-bit integer matrix.
Args:
qmatrix (torch.Tensor): matrix of packed integers
direction (str): direction of unpacking, either "column" or "row"
Returns:
imatrix (torch.Tensor): matrix of integers
"""
shifts = torch.arange(0, STORAGE_BITS, Q_BITS, device=qmatrix.device)
if direction == "column":
imatrix = torch.bitwise_right_shift(
qmatrix[:, :, None], shifts[None, None, :]
).view(qmatrix.shape[0], -1)
elif direction == "row":
imatrix = torch.bitwise_right_shift(
qmatrix[:, None, :], shifts[None, :, None]
).view(-1, qmatrix.shape[-1])
    imatrix = imatrix.to(torch.int8) & 0x0F  # mask to 4 bits to correct potential overflow
return imatrix


def quantize(fmatrix, scales, zeros, group_size):
"""
Quantizes a matrix of 16-bit floats into a matrix of 4-bit integers.
Args:
fmatrix (torch.Tensor): matrix of 16-bit floats
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
imatrix (torch.Tensor): matrix of 4-bit integers
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = torch.round(
(
fmatrix / scales.repeat_interleave(group_size, dim=0)
+ zeros.repeat_interleave(group_size, dim=0)
)
)
imatrix = imatrix.to(torch.int8) & 0x0F
return imatrix


def dequantize(imatrix, scales, zeros, group_size):
"""
Dequantizes a 4-bit integer matrix into a float matrix.
Args:
imatrix (torch.Tensor): matrix of 4-bit integers
scales (torch.Tensor): matrix of 16-bit floats
zeros (torch.Tensor): matrix of 4-bit integers
group_size (int): group size
Returns:
fmatrix (torch.Tensor): matrix of 16-bit floats
"""
zeros = zeros.to(torch.int8) & 0x0F
imatrix = imatrix.to(torch.int8) & 0x0F
fmatrix = (
imatrix - zeros.repeat_interleave(group_size, dim=0)
) * scales.repeat_interleave(group_size, dim=0)
fmatrix = fmatrix.to(torch.float16)
return fmatrix


def apply_order(
imatrix: torch.Tensor,
direction: str = "column",
order: List[int] = ORDINAL_PACK_ORDER,
):
"""
Applies the order to a 4-bit integer matrix.
Args:
imatrix (torch.Tensor): matrix of integers
direction (str): direction of applying order, either "column" or "row"
order (List[int]): order to apply, default is ordinal packing order
Returns:
imatrix (torch.Tensor): matrix of integers
"""
if direction == "column":
imatrix = imatrix.view(-1, PACK_NUM)[:, order].view(imatrix.shape)
elif direction == "row":
imatrix = imatrix.view(PACK_NUM, -1)[order, :].view(imatrix.shape)
return imatrix


def awq_to_exllama(qweight, qzeros):
# awq uses column packing for both weights and zeros
izeros = unpack(qzeros, direction="column")
iweights = unpack(qweight, direction="column")
# Reverse the order of the iweight and izeros tensors
izeros = apply_order(izeros, direction="column", order=REVERSE_AWQ_PACK_ORDER)
iweights = apply_order(iweights, direction="column", order=REVERSE_AWQ_PACK_ORDER)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
izeros = izeros - 1
# exllama uses row packing for weights and column packing for zeros
qzeros = pack(izeros, direction="column")
qweight = pack(iweights, direction="row")
return qweight, qzeros
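A small round-trip sketch of the helpers above (shapes are arbitrary, and the import path assumes the file lives at awq/utils/packing_utils.py): eight 4-bit values are packed per int32 along the chosen axis, unpacking with the same direction recovers them, and applying REVERSE_AWQ_PACK_ORDER undoes AWQ's interleaved nibble order.

import torch
from awq.utils.packing_utils import (
    pack,
    unpack,
    apply_order,
    AWQ_PACK_ORDER,
    REVERSE_AWQ_PACK_ORDER,
)

ints = torch.randint(0, 16, (4, 16), dtype=torch.int32)  # 4-bit values

packed = pack(ints, direction="column")        # shape (4, 2), int32
restored = unpack(packed, direction="column")  # shape (4, 16), int8
assert torch.equal(restored, ints.to(torch.int8))

# REVERSE_AWQ_PACK_ORDER is the inverse permutation of AWQ_PACK_ORDER,
# so applying one after the other restores the original layout.
shuffled = apply_order(ints, direction="column", order=AWQ_PACK_ORDER)
assert torch.equal(
    apply_order(shuffled, direction="column", order=REVERSE_AWQ_PACK_ORDER), ints
)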
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer

quant_path = "TheBloke/Mistral-7B-Instruct-v0.2-AWQ"

# Load model
...
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer, TextStreamer
quant_path = "IlyasMoutawwakil/vicuna-7b-v1.5-awq-marlin"
# Load model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=False)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
# Convert prompt to tokens
prompt_template = """\
<|system|>
</s>
<|user|>
{prompt}</s>
<|assistant|>"""
prompt = "You're standing on the surface of the Earth. "\
"You walk one mile south, one mile west and one mile north. "\
"You end up exactly where you started. Where are you?"
tokens = tokenizer(
prompt_template.format(prompt=prompt),
return_tensors='pt'
).input_ids.cuda()
# Generate output
generation_output = model.generate(
tokens,
streamer=streamer,
max_new_tokens=512
)
\ No newline at end of file
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
model_path = 'lmsys/vicuna-7b-v1.5'
quant_path = 'vicuna-7b-v1.5-awq-marlin'
quant_config = { "zero_point": False, "q_group_size": 128, "w_bit": 4, "version": "Marlin" }
# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
# Quantize
model.quantize(tokenizer, quant_config=quant_config)
# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)
print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
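Since Marlin only supports symmetric quantization, the config above sets zero_point to False. Conceptually, each weight group is mapped to 4-bit integers with a single scale and no zero point; the sketch below mirrors the rounding used in WQLinear_Marlin.from_linear (the tensor and the scale rule are illustrative assumptions).

import torch

group = torch.randn(128, dtype=torch.float16)  # one weight group

maxq = 2**4 - 1                # 15
scale = group.abs().max() / 7  # symmetric: a single scale, no zero point

q = torch.clamp(torch.round(group / scale) + (maxq + 1) // 2, 0, maxq)  # values in [0, 15]
dequant = (q - (maxq + 1) // 2) * scale                                 # ≈ original group
max_abs_error = (dequant - group).abs().max()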