Commit fcd9637c authored by gaoqiong

Merge branch 'v0.2.5_develop' into 'main'

v0.2.5

See merge request dcutoolkit/deeplearing/autoawq!2
parents 7724cca1 427f5481
import torch
class WindowedCache:
def __init__(self, cache_v_shape, cache_k_shape, max_seq_len, device):
"""
The window size is the same as the max_seq_len. The window will
automatically roll once max_seq_len is exceeded.
"""
# [batch_size, n_kv_heads, max_seq_len, head_dim]
self.v = torch.zeros(cache_v_shape).to(device).half()
# [batch_size, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
self.k = torch.zeros(cache_k_shape).to(device).half()
self.max_seq_len = max_seq_len
def get_kv(self, batch_size, start_pos, seqlen, head_dim):
"""
Gets the key-value store in correct shapes.
"""
xv = (
self.v[:batch_size, :, : start_pos + seqlen, :].transpose(1, 2).contiguous()
)
xk = (
self.k[:batch_size, :, :, : start_pos + seqlen, :]
.transpose(2, 3)
.contiguous()
)
xk = xk.reshape(xk.shape[:-2] + (head_dim,)).transpose(1, 2).contiguous()
return xv, xk
def update_kv(self, values_store, keys_store, batch_size, start_pos, seqlen):
"""
Updates the values in the key-value store.
"""
self.v[:batch_size, :, start_pos : start_pos + seqlen, :] = values_store
self.k[:batch_size, :, :, start_pos : start_pos + seqlen, :] = keys_store
def roll_kv_n_steps(self, start_pos, n=100):
"""
Roll the cache n steps to the left, zero the vacated positions, and return the updated start_pos.
"""
n = min(n, self.max_seq_len)
# Roll cache to the left
self.v = torch.roll(self.v, shifts=-n, dims=2)
self.k = torch.roll(self.k, shifts=-n, dims=3)
# Zero out the new part
self.v[:, :, -n:, :] = 0
self.k[:, :, :, -n:, :] = 0
return start_pos - n
def to(self, device):
self.k = self.k.to(device)
self.v = self.v.to(device)
def increase_batch_size(self, to_bsz):
"""Dynamically allocate new kv when batch size changes."""
self.v = torch.zeros(
to_bsz, *self.v.shape[1:], dtype=self.v.dtype, device=self.v.device
)
self.k = torch.zeros(
to_bsz, *self.k.shape[1:], dtype=self.k.dtype, device=self.k.device
)
def decrease_batch_size(self, to_bsz):
"""Dynamically remove part of cache if batch size changes."""
self.v = self.v[:to_bsz, :, :, :]
self.k = self.k[:to_bsz, :, :, :, :]
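def _example_windowed_cache_usage():
    # Hedged usage sketch (not part of the original source): exercises
    # WindowedCache with tiny, hypothetical dimensions. Assumed layout:
    #   v: [batch, n_kv_heads, max_seq_len, head_dim]
    #   k: [batch, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor]
    batch, n_kv_heads, max_seq_len, head_dim, pack_factor = 1, 2, 8, 4, 2
    cache = WindowedCache(
        cache_v_shape=(batch, n_kv_heads, max_seq_len, head_dim),
        cache_k_shape=(batch, n_kv_heads, head_dim // pack_factor, max_seq_len, pack_factor),
        max_seq_len=max_seq_len,
        device="cpu",
    )
    seqlen = 3
    values = torch.ones(batch, n_kv_heads, seqlen, head_dim).half()
    keys = torch.ones(batch, n_kv_heads, head_dim // pack_factor, seqlen, pack_factor).half()
    cache.update_kv(values, keys, batch, start_pos=0, seqlen=seqlen)
    xv, xk = cache.get_kv(batch, start_pos=0, seqlen=seqlen, head_dim=head_dim)
    # Once start_pos would exceed max_seq_len, the window is shifted left:
    new_start_pos = cache.roll_kv_n_steps(start_pos=seqlen, n=2)
    return xv.shape, xk.shape, new_start_pos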
import torch.nn as nn
import torch.nn.functional as F
from awq.modules.linear.gemm import WQLinear_GEMM
from awq.modules.linear.gemv import WQLinear_GEMV
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class QuantFusedMLP(nn.Module):
def __init__(
self,
gate_proj,
down_proj,
up_proj,
activation=F.silu,
):
super().__init__()
self.register_buffer("gate_proj_qweight", gate_proj.qweight)
self.register_buffer("gate_proj_scales", gate_proj.scales)
self.register_buffer("gate_proj_qzeros", gate_proj.qzeros)
self.register_buffer("up_proj_qweight", up_proj.qweight)
self.register_buffer("up_proj_scales", up_proj.scales)
self.register_buffer("up_proj_qzeros", up_proj.qzeros)
self.in_features = gate_proj.in_features
self.intermediate_size = gate_proj.out_features
self.out_features = down_proj.out_features
self.w_bit = gate_proj.w_bit
self.down_proj = down_proj
if isinstance(down_proj, WQLinear_GEMV):
self.linear = awq_ext.gemv_forward_cuda
self.group_size = down_proj.group_size
else:
self.linear = awq_ext.gemm_forward_cuda
self.group_size = 8
self.activation = activation
def forward(self, x, routing_weights=None):
out_shape = x.shape[:-1] + (self.intermediate_size,)
x = x.reshape(-1, x.shape[-1])
gate_output = self.linear(
x,
self.gate_proj_qweight,
self.gate_proj_scales,
self.gate_proj_qzeros,
self.group_size,
)
up_output = self.linear(
x,
self.up_proj_qweight,
self.up_proj_scales,
self.up_proj_qzeros,
self.group_size,
)
x = self.activation(gate_output) * up_output
x = x.reshape(out_shape)
x = self.down_proj(x)
if routing_weights is not None:
x = routing_weights * x
return x
class QuantLlamaMLP(QuantFusedMLP):
r"""
QuantLlamaMLP is kept for backward compatibility; new code should use the
`QuantFusedMLP` class instead.
"""
def __init__(self, gate_proj, down_proj, up_proj):
super().__init__(gate_proj, down_proj, up_proj)
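def _example_fuse_llama_mlp(gate_proj, up_proj, down_proj):
    # Hedged usage sketch (illustration only): the three arguments are assumed
    # to be already-quantized WQLinear_GEMM or WQLinear_GEMV projections taken
    # from a loaded Llama-style block. The fused module replaces the separate
    # gate/up/down projections plus the SiLU activation with a single forward
    # pass backed by the awq_ext kernels (AWQ_INSTALLED must be True to run it).
    return QuantFusedMLP(gate_proj, down_proj, up_proj, activation=F.silu)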
import torch
import torch.nn as nn
from typing import List
from awq.utils import fused_utils
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
MoeModelOutputWithPast,
)
from awq.modules.fused.block import (
MPTBlock,
FalconDecoderLayer,
LlamaLikeBlock,
MixtralBlock,
)
class MixtralModel(nn.Module):
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.embedding(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm(h)
return MoeModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
router_logits=(),
)
class LlamaLikeModel(nn.Module):
"""
LlamaLikeModel is intended to be reused across models that have
an architecture that closely resembles Llama, e.g. Mistral and Aquila.
"""
def __init__(self, vocab_size, blocks, embedding, norm):
super().__init__()
self.vocab_size = vocab_size
self.embedding = embedding
self.blocks: List[LlamaLikeBlock] = nn.ModuleList(blocks)
self.norm = norm
self.last_forward_num_tokens = 0
@property
def embed_tokens(self):
return self.embedding
@property
def layers(self):
return self.blocks
@torch.inference_mode()
def forward(
self,
input_ids: torch.Tensor,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.embedding(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, _ = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm(h)
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=None,
hidden_states=(),
attentions=(),
)
class MPTModel(nn.Module):
def __init__(self, vocab_size, blocks, wte, norm_f):
super().__init__()
self.vocab_size = vocab_size
self.wte = wte
self.blocks: List[MPTBlock] = nn.ModuleList(blocks)
self.norm_f = norm_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.wte(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.norm_f(h)
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
class FalconModel(nn.Module):
def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
super().__init__()
self.vocab_size = vocab_size
self.word_embeddings = word_embeddings
self.blocks: List[FalconDecoderLayer] = nn.ModuleList(blocks)
self.ln_f = ln_f
self.attn_uses_sequence_id = False
self.prefix_lm = False
self.last_forward_num_tokens = 0
@torch.inference_mode()
def forward(
self,
input_ids,
attn_bias=None,
attention_mask=None,
is_causal=None,
*args,
**kwargs,
):
input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
input_ids, self.last_forward_num_tokens
)
_bsz, seqlen = input_ids.shape
fused_utils.prepare_cache(self.blocks, seqlen)
h = self.word_embeddings(input_ids)
mask = fused_utils.prepare_attention_mask(
seqlen=seqlen,
start_pos=self.blocks[0].attn.start_pos,
device=input_ids.device,
type_as=h,
)
for layer in self.blocks:
h, mask = fused_utils.prepare_correct_devices(
layer,
h,
mask,
)
h, _, past_key_value = layer(
h, None, attention_mask=mask, is_causal=is_causal
)
h = self.ln_f(h)
return BaseModelOutputWithPast(
last_hidden_state=h,
past_key_values=past_key_value,
hidden_states=(),
attentions=(),
)
import torch
from typing import Dict
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class FusedSparseMoeBlock(torch.nn.Module):
def __init__(
self,
top_k,
gate,
ws,
w2s,
):
super().__init__()
self.gate = gate
self.top_k = top_k
self.ws = ws
self.w2s = w2s
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (batch * sequence_length, n_experts)
router_logits = self.gate(hidden_states)
final_hidden_states = apply_moe_weights(
self.ws,
self.w2s,
hidden_states,
router_logits,
self.top_k,
renormalize=True,
)
return final_hidden_states.view(batch_size, sequence_length, hidden_dim)
def apply_moe_weights(
w1: Dict[str, torch.Tensor],
w2: Dict[str, torch.Tensor],
x: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
) -> torch.Tensor:
topk_weights, topk_ids = fused_topk(gating_output, topk, renormalize)
(sorted_token_ids, expert_ids, num_tokens_post_padded) = moe_align_block_size(
topk_ids, 16, w1.qweight.shape[0]
)
x = x.view(x.shape[0], 1, *x.shape[1:])
gate_up = awq_ext.grouped_gemm_forward(
x,
w1.qweight,
w1.scales,
w1.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
8,
)
out = torch.empty(
(gate_up.shape[:-1] + (gate_up.shape[-1] // 2,)), dtype=x.dtype, device=x.device
)
awq_ext.silu_and_mul(out, gate_up)
out = awq_ext.grouped_gemm_forward(
out,
w2.qweight,
w2.scales,
w2.qzeros,
topk_weights,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
8,
)
return torch.sum(out, dim=1)
def moe_align_block_size(topk_ids: torch.Tensor, block_size: int, num_experts: int):
"""
Aligns the token distribution across experts to be compatible with block size for matrix multiplication.
Parameters:
- topk_ids: A tensor of shape [total_tokens, top_k] representing the top-k expert indices for each token.
- block_size: The block size used in block matrix multiplication.
- num_experts: The total number of experts.
Returns:
- sorted_token_ids: A tensor containing the sorted token indices according to their allocated expert.
- expert_ids: A tensor indicating the assigned expert index for each block.
- num_tokens_post_padded: The total number of tokens after padding, ensuring divisibility by block_size.
This function pads the number of tokens that each expert needs to process so that it is divisible by block_size.
Padding ensures that during block matrix multiplication, the dimensions align correctly.
Example:
Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]], block_size = 4, and num_experts = 4:
- We initially have 12 tokens (after repeating 'top_k' times) and 4 experts, with each expert needing to process 3 tokens.
- As block_size is 4, we pad 1 token for each expert.
- First, flatten topk_ids to [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3].
- Then append padding tokens [12, 12, 12, 12] for each block.
- After sorting by expert index, we obtain token_ids [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
Tokens 12 are non-existent (padding) and are ignored in the subsequent matrix multiplication.
- The padding ensures that the total number of tokens is now divisible by block_size for proper block matrix operations.
"""
sorted_ids = torch.empty(
(topk_ids.numel() + num_experts * (block_size - 1),),
dtype=torch.int32,
device=topk_ids.device,
)
expert_ids = torch.empty(
(topk_ids.numel() + num_experts,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
awq_ext.moe_alig_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
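def _moe_align_block_size_reference(topk_ids: torch.Tensor, block_size: int, num_experts: int):
    # Hedged pure-Python sketch of the alignment idea behind
    # awq_ext.moe_alig_block_size (illustration only; the CUDA kernel's exact
    # handling of empty experts and its output layout may differ).
    flat = topk_ids.flatten().tolist()
    pad_id = len(flat)  # sentinel index used for padding slots
    sorted_token_ids, expert_ids = [], []
    for expert in range(num_experts):
        tokens = [i for i, e in enumerate(flat) if e == expert]
        # pad this expert's token list up to the next multiple of block_size
        while len(tokens) % block_size != 0:
            tokens.append(pad_id)
        sorted_token_ids.extend(tokens)
        expert_ids.extend([expert] * (len(tokens) // block_size))
    return sorted_token_ids, expert_ids, len(sorted_token_ids)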
def fused_topk(
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
"""Compute top-k indice and weights from gating logits
Args:
gating_output (torch.Tensor): The output of the gating operation (before softmax).
topk (int): The number of top-k experts to select.
renormalize (bool): If True, renormalize the top-k weights to sum to 1.
"""
M = gating_output.shape[0]
if torch.version.hip is not None:
# The MoE kernels are not yet supported on ROCm.
routing_weights = torch.softmax(gating_output, dim=-1, dtype=torch.float32)
topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
else:
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=gating_output.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=gating_output.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=gating_output.device
)
awq_ext.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
import torch
from torch import nn
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
class FasterTransformerRMSNorm(nn.Module):
def __init__(self, weight, eps=1e-6):
super().__init__()
self.weight = weight
self.variance_epsilon = eps
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
output = torch.empty_like(x)
awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
return output
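def _rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6):
    # Hedged eager-mode sketch (not part of the original source) of what the
    # fused layernorm_forward_cuda kernel is expected to compute: standard
    # RMSNorm, i.e. x scaled by the reciprocal root-mean-square of its last
    # dimension and then by the learned weight. Useful for sanity-checking
    # kernel outputs on small tensors.
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    return (x.to(torch.float32) * torch.rsqrt(variance + eps)).to(x.dtype) * weight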
from .exllama import WQLinear_Exllama, exllama_post_init
from .exllamav2 import WQLinear_ExllamaV2, exllamav2_post_init
from .gemm import WQLinear_GEMM
from .gemv import WQLinear_GEMV
from .marlin import WQLinear_Marlin, marlin_post_init
from .gemv_fast import WQLinear_GEMVFast
import torch
import torch.nn as nn
from awq.utils.packing_utils import unpack_reorder_pack
try:
import exl_ext # with CUDA kernels (AutoAWQ_kernels)
EXL_INSTALLED = True
except:
EXL_INSTALLED = False
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_Exllama(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for Exllama kernels")
self.q4 = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
self.q4 = exl_ext.make_q4(
self.qweight,
self.qzeros,
self.scales,
none_tensor, # g_idx
self.qweight.device.index, # device index
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for Exllama kernels")
def forward(self, x):
assert self.q4 is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllama_post_init() on the whole model."
)
assert EXL_INSTALLED, (
"Exllama kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exl_ext.q4_matmul(x, self.q4, out)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def exllama_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Exllama):
submodule.post_init()
return model
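def _example_exllama_init_flow(quantized_linear):
    # Hedged usage sketch (illustration only): `quantized_linear` is assumed to
    # expose in_features / out_features / bias / weight like nn.Linear.
    # init_only=True builds an empty Exllama shell whose buffers match the
    # WQLinear_GEMM state_dict layout; after loading weights, the whole model
    # is finalized once with exllama_post_init(), which repacks qweight/qzeros
    # and creates the q4 handle used by forward().
    shell = WQLinear_Exllama.from_linear(
        quantized_linear, w_bit=4, group_size=128, init_only=True
    )
    return shell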
import torch
import torch.nn as nn
from typing import Dict
from awq.utils.packing_utils import unpack_reorder_pack
try:
import exlv2_ext # with CUDA kernels (AutoAWQ_kernels)
EXLV2_INSTALLED = True
except:
EXLV2_INSTALLED = False
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device="meta")
class WQLinear_ExllamaV2(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.q_handle = None
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
##################################################################################
## These shapes are only for compatibility with the state_dict of WQLinear_GEMM ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
##################################################################################
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
def post_init(self, scratch_space: "ScratchSpace"):
assert self.qweight.device.type == "cuda"
assert self.qweight.device.index is not None
self.qweight, self.qzeros = unpack_reorder_pack(
self.qweight, self.qzeros, self.w_bit
)
temp_dq_size = self.temp_dq_size()
temp_dq = scratch_space.get_slice(temp_dq_size)
self.q_handle = exlv2_ext.make_q_matrix(
self.qweight,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
none_tensor,
self.qzeros,
self.scales,
none_tensor,
temp_dq,
)
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
raise NotImplementedError("Only inference is supported for ExllamaV2 kernels")
def temp_dq_size(self):
"""
Returns the size of the temporary buffer required for the dq kernel.
"""
return self.in_features * self.out_features * 2 + 128
def temp_fwd_size(self, max_input_len, max_batch_size):
"""
Returns the size of the temporary buffer required for the fwd kernel.
"""
return self.out_features * max_input_len * max_batch_size * 4 + 128
def scratch_space_fixed(self, max_input_len=2048, max_batch_size=8):
"""
Returns the size of the fixed scratch space required for the kernel.
"""
return self.temp_dq_size() + self.temp_fwd_size(max_input_len, max_batch_size)
def forward(self, x):
assert self.q_handle is not None, (
"module.post_init() must be called before module.forward(). "
"Use exllamav2_post_init() on the whole model."
)
assert EXLV2_INSTALLED, (
"ExllamaV2 kernels are not installed. "
"Please install AWQ compatible ExllamaV2 kernels from AutoAWQ_kernels."
)
input_dtype = x.dtype
out_shape = x.shape[:-1] + (self.out_features,)
if input_dtype != torch.float16:
x = x.to(dtype=torch.float16)
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
exlv2_ext.gemm_half_q_half(x, self.q_handle, out, False)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
class ScratchSpace:
def __init__(self, scratch_bytes, dev):
self.scratch_bytes = scratch_bytes
self.scratch = torch.empty(
self.scratch_bytes // 2,
dtype=torch.float16,
device=dev,
)
def get_slice(self, size_bytes):
size_halfs = next_multiple(size_bytes, 128) // 2
scratch_slice = self.scratch.narrow(0, 0, size_halfs)
return scratch_slice
def exllamav2_post_init(model, max_input_len: int = 2048, max_batch_size: int = 8):
# we search for the maximum number of bytes required for each device's scratch space
fixed_bytes: Dict[torch.device, int] = {}
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
scratch_fixed = submodule.scratch_space_fixed(
max_input_len=max_input_len, max_batch_size=max_batch_size
)
fixed_bytes[device] = max(fixed_bytes.get(device, 0), scratch_fixed)
# we allocate a model-persistent scratch space for each device
model.scratch_spaces: Dict[torch.device, ScratchSpace] = {}
for device, scratch_bytes in fixed_bytes.items():
model.scratch_spaces[device] = ScratchSpace(scratch_bytes, device)
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_ExllamaV2):
device = submodule.qweight.device
submodule.post_init(scratch_space=model.scratch_spaces[device])
return model
def next_multiple(x, multiple):
return ((x + multiple - 1) // multiple) * multiple
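def _example_scratch_sizing():
    # Hedged sizing illustration (not part of the original source): how the
    # fixed scratch space for a single WQLinear_ExllamaV2 layer is derived,
    # assuming a hypothetical 4096 -> 11008 projection and the default limits
    # used by exllamav2_post_init (max_input_len=2048, max_batch_size=8).
    in_features, out_features = 4096, 11008
    max_input_len, max_batch_size = 2048, 8
    temp_dq = in_features * out_features * 2 + 128  # dequant buffer, bytes
    temp_fwd = out_features * max_input_len * max_batch_size * 4 + 128  # fwd buffer, bytes
    # exllamav2_post_init keeps the per-device maximum of this total across layers.
    return temp_dq + temp_fwd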
import torch
import torch.nn as nn
from torch.autograd import Function
from awq.utils.utils import get_best_device
from awq.utils.packing_utils import dequantize_gemm
try:
import awq_ext # with CUDA kernels (AutoAWQ_kernels)
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
# Adapted from https://github.com/compressa-ai/AutoAWQ/tree/dev
class WQLinearMMFunction(Function):
@staticmethod
# ctx is the first argument to forward
def forward(
ctx,
x,
qweight,
qzeros,
scales,
w_bit=4,
group_size=128,
bias=None,
out_features=0,
):
# The forward pass can use ctx.
ctx.save_for_backward(x, qweight, qzeros, scales, bias)
ctx.out_features = out_features
out_shape = x.shape[:-1] + (out_features,)
x = x.to(torch.float16)
if AWQ_INSTALLED:
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024
if FP16_MATMUL_HEURISTIC_CONDITION:
out = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 0, 0, 0, False
)
out = torch.matmul(x, out)
else:
out = awq_ext.gemm_forward_cuda(
x.reshape(-1, x.shape[-1]), qweight, scales, qzeros, 8
)
else:
out = dequantize_gemm(qweight, qzeros, scales, w_bit, group_size)
out = torch.matmul(x, out)
out = out + bias if bias is not None else out
out = out.reshape(out_shape)
# always want 3D tensor if tensor is 2D
if len(out.shape) == 2:
out = out.unsqueeze(0)
return out
@staticmethod
def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors
if not AWQ_INSTALLED:
raise ValueError(
"auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
" by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
)
# Cast to correct dtype for mixed precision training
weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False
).to(grad_output.dtype)
if ctx.needs_input_grad[0]:
# 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
# to propagate gradient across all batch sizes.
batch_size = grad_output.shape[0]
grad_input = grad_output.bmm(weights.transpose(0, 1).unsqueeze(0).repeat(batch_size, 1, 1))
return grad_input, None, None, None, None, None, None, None
class WQLinear_GEMM(nn.Module):
def __init__(
self, w_bit, group_size, in_features, out_features, bias, dev, training=False
):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.training = training
# quick sanity check (make sure the dimensions are properly aligned)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
self.register_buffer(
"qweight",
torch.zeros(
(in_features, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(in_features // self.group_size, out_features // (32 // self.w_bit)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // self.group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
awq_linear.scales = scales.clone().half()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
pack_num = 32 // awq_linear.w_bit
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[idx // group_size])
/ awq_linear.scales[idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.to(dtype=torch.int32)
best_device = get_best_device()
# Avoid: The operator 'aten::__lshift__.Scalar' is not currently implemented for the MPS device
if "mps" in best_device:
intweight = intweight.to("cpu")
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32, device=best_device)
if "mps" in best_device:
zeros = zeros.to("cpu")
qzeros = torch.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=zeros.device,
)
for col in range(zeros.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 2, 4, 6, 1, 3, 5, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
def forward(self, x):
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
if self.training:
out = WQLinearMMFunction.apply(
x,
self.qweight,
self.qzeros,
self.scales,
self.w_bit,
self.group_size,
self.bias,
self.out_features,
)
else:
with torch.no_grad():
out = WQLinearMMFunction.apply(
x,
self.qweight,
self.qzeros,
self.scales,
self.w_bit,
self.group_size,
self.bias,
self.out_features,
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
return out.reshape(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
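def _example_pack_int4_column():
    # Hedged illustration (not part of the original source) of the packing loop
    # in WQLinear_GEMM.from_linear: eight already-quantized 4-bit values are
    # interleaved with order_map = [0, 2, 4, 6, 1, 3, 5, 7] and OR-ed into a
    # single int32, so nibble i of the packed word holds values[order_map[i]].
    values = [0, 1, 2, 3, 4, 5, 6, 7]
    order_map = [0, 2, 4, 6, 1, 3, 5, 7]
    packed = 0
    for i in range(8):
        packed |= values[order_map[i]] << (i * 4)
    return packed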
import torch
import torch.nn as nn
try:
import awq_ext # with CUDA kernels
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
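def _example_zeros_width():
    # Hedged worked example (illustration only): for a hypothetical 4096-wide
    # input with group_size=128 there are 4096 // 128 = 32 groups, which pack
    # into ceil(32 / 8) = 4 int32 columns of qzeros; the scales buffer then
    # gets calculate_zeros_width(...) * pack_num = 32 float16 columns so both
    # buffers share the same column count.
    assert calculate_zeros_width(4096, group_size=128, pack_num=8) == 4
    return calculate_zeros_width(4096, group_size=128, pack_num=8)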
class WQLinear_GEMV(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
# quick sanity check (make sure the dimensions are properly aligned)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = 32 // self.w_bit
self.register_buffer(
"qweight",
torch.zeros(
(out_features, in_features // pack_num), dtype=torch.int32, device=dev
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(out_features, calculate_zeros_width(in_features, self.group_size)),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(
out_features,
calculate_zeros_width(in_features, self.group_size) * pack_num,
),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
)
else:
self.bias = None
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(linear.in_features, group_size) * pack_num,
),
dtype=torch.float16,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
awq_linear.scales = qscales
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ awq_linear.scales[:, idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
qweight = torch.zeros(
(intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit),
dtype=torch.int32,
device=intweight.device,
)
for col in range(intweight.shape[1] // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
qweight_col = intweight[:, col * pack_num + order_map[i]]
qweight[:, col] |= qweight_col << (i * awq_linear.w_bit)
awq_linear.qweight = qweight
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros(
(zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)),
dtype=torch.int32,
device=zeros.device,
)
for col in range((zeros.shape[1] + pack_num - 1) // pack_num):
if awq_linear.w_bit == 4:
order_map = [0, 1, 2, 3, 4, 5, 6, 7]
else:
raise NotImplementedError("Only 4-bit are supported for now.")
for i in range(pack_num):
if col * pack_num + order_map[i] >= zeros.shape[1]:
continue
qzero_col = zeros[:, col * pack_num + order_map[i]]
qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit)
awq_linear.qzeros = qzeros
return awq_linear
@torch.no_grad()
def forward(self, x):
assert AWQ_INSTALLED, (
"AWQ kernels could not be loaded. "
"Please install them from https://github.com/casper-hansen/AutoAWQ_kernels"
)
out_shape = x.shape[:-1] + (self.out_features,)
inputs = x.reshape(-1, x.shape[-1])
input_dtype = inputs.dtype
if input_dtype != torch.float16:
inputs = inputs.half()
if inputs.shape[0] > 8:
out = awq_ext.gemmv2_forward_cuda(
inputs,
self.qweight,
self.scales,
self.qzeros,
self.group_size,
self.split_k_iters,
)
else:
out = awq_ext.gemv_forward_cuda(
inputs, self.qweight, self.scales, self.qzeros, self.group_size
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
import torch
try:
import awq_v2_ext # with CUDA kernels (AutoAWQ_kernels)
AWQ_INSTALLED = True
except:
AWQ_INSTALLED = False
def make_divisible(c, divisor):
return (c + divisor - 1) // divisor
def calculate_zeros_width(in_features, group_size=128, pack_num=8):
if group_size >= 128:
size_multiplier = 1
elif group_size == 64:
size_multiplier = 2
elif group_size == 32:
size_multiplier = 4
else:
raise NotImplementedError
base_width = make_divisible(in_features // group_size, pack_num)
base_width = make_divisible(base_width, size_multiplier) * size_multiplier
return base_width
def pack_intweight(unpacked_qweight, interleave, kstride):
# unpacked_qweight: [N, K]
N = unpacked_qweight.shape[0]
K = unpacked_qweight.shape[1]
Packed_Kernel = unpacked_qweight.cpu().numpy().reshape(N, K // 32, 32)
# np.arange(32).reshape(4, 4, 2).transpose(1, 0, 2) => [0, 1, 8, 9, 16, 17, 24, 25, ...]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 3, 2, 4)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 32)
# reorder each 8 weights for fast dequantization
# [0, 1, 2, 3, 4, 5, 6, 7] => [0, 2, 4, 6, 1, 3, 5, 7]
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 8)
Packed_Kernel = Packed_Kernel.reshape(N, K // 32, 4, 4, 2).transpose(0, 1, 2, 4, 3)
Packed_Kernel = Packed_Kernel.reshape(N, K)
# interleaving every four rows
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, interleave, K // kstride, kstride
)
# N // 4, K // 64, 4, 64
Packed_Kernel = Packed_Kernel.transpose(0, 2, 1, 3)
Packed_Kernel = Packed_Kernel.reshape(
N // interleave, K // kstride, kstride, interleave
)
# Packing -> (N // 4, K // 64, 64)
Packed_Kernel = (
Packed_Kernel[..., 0]
| (Packed_Kernel[..., 1] << 4)
| (Packed_Kernel[..., 2] << 8)
| (Packed_Kernel[..., 3] << 12)
)
# reshape to (N // 4, K), FP16 format
Packed_Kernel = Packed_Kernel.reshape(N // interleave, K)
qweight = (
torch.tensor(Packed_Kernel.astype("int16"))
.to(unpacked_qweight.device)
.contiguous()
)
return qweight
class WQLinear_GEMVFast(torch.nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.w_bit = w_bit
self.group_size = group_size if group_size != -1 else in_features
self.split_k_iters = 8
self.interleave = 4
# quick sanity check (make sure the dimensions are properly aligned)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
pack_num = 32 // self.w_bit
int16_pack_num = 16 // self.w_bit
assert out_features % (self.interleave) == 0
self.register_buffer(
"qweight",
torch.zeros(
(
out_features // self.interleave,
in_features // int16_pack_num * self.interleave,
),
dtype=torch.int16,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
calculate_zeros_width(in_features, self.group_size) * pack_num,
out_features,
),
dtype=torch.float16,
device=dev,
),
)
if bias:
self.register_buffer(
"bias", torch.zeros((out_features), dtype=torch.float16, device=dev)
)
else:
self.bias = None
@classmethod
def from_linear(
cls, linear, w_bit, group_size, init_only=False, scales=None, zeros=None
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only:
return awq_linear
# need scales and zeros info for real quantization
assert scales is not None and zeros is not None
scale_zeros = zeros * scales
pack_num = 32 // awq_linear.w_bit
qscales = torch.zeros(
(
scales.shape[0],
calculate_zeros_width(linear.in_features, group_size) * pack_num,
),
dtype=torch.float16,
device=scales.device,
)
qscales[:, : scales.shape[1]] = scales
# awq_linear.scales = scales.clone().half()
awq_linear.scales = qscales.transpose(1, 0).contiguous()
if linear.bias is not None:
awq_linear.bias = linear.bias.clone().half()
intweight = []
for idx in range(awq_linear.in_features):
intweight.append(
torch.round(
(linear.weight.data[:, idx] + scale_zeros[:, idx // group_size])
/ qscales[:, idx // group_size]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.to(dtype=torch.int32)
awq_linear.qweight = pack_intweight(
intweight.contiguous(), interleave=4, kstride=64
)
zeros = zeros.to(dtype=torch.int32)
qzeros = torch.zeros_like(qscales)
qzeros[:, : scales.shape[1]] = -(
qscales[:, : scales.shape[1]] * (zeros.to(torch.float32))
).to(torch.float16)
awq_linear.qzeros = qzeros.transpose(1, 0).contiguous()
return awq_linear
@torch.no_grad()
def forward(self, x):
inputs = x
batch_size, n_tokens, _ = inputs.shape
if batch_size < 8 and n_tokens == 1:
out = awq_v2_ext.gemv_forward_cuda_decode(
inputs,
self.qweight,
self.scales,
self.qzeros,
inputs.numel() // inputs.shape[-1],
self.out_features,
self.in_features,
self.group_size,
)
else:
out = awq_v2_ext.gemm_forward_cuda_prefill(
inputs, self.qweight, self.scales, self.qzeros
)
out = out + self.bias if self.bias is not None else out
return out
import torch
import torch.nn as nn
import numpy as np
try:
import marlin_cuda # with CUDA kernels (AutoAWQ_kernels)
MARLIN_INSTALLED = True
except:
MARLIN_INSTALLED = False
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
class WQLinear_Marlin(nn.Module):
def __init__(self, w_bit, group_size, in_features, out_features, bias, dev):
super().__init__()
if w_bit not in [4]:
raise NotImplementedError("Only 4-bit are supported for now.")
self.w_bit = w_bit
self.in_features = in_features
self.out_features = out_features
self.group_size = group_size if group_size != -1 else in_features
self.max_par = 8 # partitioning for large inputs
# quick sanity check (make sure the dimensions are properly aligned)
assert self.in_features % self.group_size == 0
assert out_features % (32 // self.w_bit) == 0
######################################################
## These shapes are only specific for Marlin models ##
self.register_buffer(
"qweight",
torch.zeros(
(in_features // 16, out_features * 16 // 8),
dtype=torch.int32,
device=dev,
),
)
self.register_buffer(
"scales",
torch.zeros(
(in_features // group_size, out_features),
dtype=torch.float16,
device=dev,
),
)
######################################################
if bias:
self.register_buffer(
"bias",
torch.zeros(
(out_features),
dtype=torch.float16,
device=dev,
),
)
else:
self.bias = None
@classmethod
def from_linear(
cls,
linear,
w_bit,
group_size,
init_only=False,
scales=None,
zeros=None,
):
awq_linear = cls(
w_bit,
group_size,
linear.in_features,
linear.out_features,
linear.bias is not None,
linear.weight.device,
)
if init_only: # just prepare for loading sd
return awq_linear
assert zeros is None and scales is not None
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((-1, awq_linear.group_size, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape((awq_linear.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if awq_linear.group_size != awq_linear.in_features:
w = w.reshape((awq_linear.group_size, -1, awq_linear.out_features))
w = w.permute(1, 0, 2)
w = w.reshape(
(awq_linear.in_features, awq_linear.out_features)
).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, awq_linear.out_features)).contiguous()
w = w.reshape(
(
awq_linear.in_features // tile,
tile,
awq_linear.out_features // tile,
tile,
)
)
w = w.permute((0, 2, 1, 3))
w = w.reshape((awq_linear.in_features // tile, awq_linear.out_features * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
awq_linear.qweight[:] = q.to(awq_linear.qweight.device)
awq_linear.scales[:] = s.to(awq_linear.qweight.device)
if awq_linear.bias is not None:
awq_linear.bias[:] = linear.bias.data.to(awq_linear.bias.device)
return awq_linear
def post_init(self):
self.register_buffer(
"workspace",
torch.zeros(
self.out_features // 128 * self.max_par,
dtype=torch.int32,
device=self.qweight.device,
),
persistent=False,
)
@torch.no_grad()
def forward(self, x):
assert hasattr(self, "workspace"), (
"module.post_init() must be called before module.forward(). "
"Use marlin_post_init() on the whole model."
)
assert MARLIN_INSTALLED, (
"Marlin kernels are not installed. "
"Please install AWQ compatible Marlin kernels from AutoAWQ_kernels."
)
out_shape = x.shape[:-1] + (self.out_features,)
input_dtype = x.dtype
if input_dtype != torch.float16:
x = x.half()
x = x.view(-1, x.shape[-1])
out = torch.empty(
(x.shape[0], self.out_features),
dtype=torch.float16,
device=x.device,
)
marlin_cuda.mul(
x,
self.qweight,
out,
self.scales,
self.workspace,
-1, # thread_k
-1, # thread_n
-1, # sms
self.max_par,
)
if input_dtype != torch.float16:
out = out.to(dtype=input_dtype)
if self.bias is not None:
out.add_(self.bias)
return out.view(out_shape)
def extra_repr(self) -> str:
return (
"in_features={}, out_features={}, bias={}, w_bit={}, group_size={}".format(
self.in_features,
self.out_features,
self.bias is not None,
self.w_bit,
self.group_size,
)
)
def marlin_post_init(model):
for _, submodule in model.named_modules():
if isinstance(submodule, WQLinear_Marlin):
submodule.post_init()
return model
import torch
import inspect
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import Dict, List, Optional
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.utils.utils import clear_memory, get_best_device
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_Marlin,
WQLinear_GEMVFast,
)
from awq.utils.module import (
append_str_prefix,
get_op_name,
get_named_linears,
set_op_by_name,
exclude_layers_to_not_quantize,
)
class AwqQuantizer:
def __init__(
self,
awq_model,
model,
tokenizer,
w_bit,
group_size,
zero_point,
version,
calib_data,
split,
text_column,
duo_scaling,
modules_to_not_convert=None,
export_compatible=False,
apply_clip=True,
) -> None:
self.awq_model = awq_model
self.model = model
self.tokenizer = tokenizer
self.w_bit = w_bit
self.group_size = group_size
self.zero_point = zero_point
self.version = version
self.calib_data = calib_data
self.split = split
self.text_column = text_column
self.duo_scaling = duo_scaling
self.export_compatible = export_compatible
self.apply_clip = apply_clip
self.modules_to_not_convert = (
modules_to_not_convert if modules_to_not_convert is not None else []
)
self.modules, self.module_kwargs, self.inps = self.init_quant()
def pseudo_quantize_tensor(self, w: torch.Tensor):
org_w_shape = w.shape
if self.group_size > 0:
assert org_w_shape[-1] % self.group_size == 0
w = w.reshape(-1, self.group_size)
assert w.dim() == 2
assert torch.isnan(w).sum() == 0
# zero point quantization
if self.zero_point:
max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
max_int = 2**self.w_bit - 1
min_int = 0
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
w = (
torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros
) * scales
zeros = zeros.view(org_w_shape[0], -1)
else:
max_val = w.abs().amax(dim=1, keepdim=True)
max_val = max_val.clamp(min=1e-5)
max_int = 2 ** (self.w_bit - 1) - 1
min_int = -(2 ** (self.w_bit - 1))
scales = max_val / max_int
zeros = None
w = torch.clamp(torch.round(w / scales), min_int, max_int) * scales
assert torch.isnan(scales).sum() == 0
assert torch.isnan(w).sum() == 0
scales = scales.view(org_w_shape[0], -1)
w = w.reshape(org_w_shape)
return w, scales, zeros
def pseudo_dequantize_tensor(
self, w: nn.Linear, scales: torch.Tensor, zeros: Optional[torch.Tensor] = None
):
# get repeated count
repeat_count = w.weight.data.shape[-1] // scales.shape[-1]
scales = scales.repeat(1, repeat_count).reshape(w.weight.data.shape)
# dequantize
if self.zero_point:
zeros = zeros.repeat(1, repeat_count).reshape(w.weight.data.shape)
w = (w.weight.data - zeros) * scales
else:
w = w.weight.data * scales
return w
def quantize(self):
for i in tqdm(range(len(self.modules)), desc="AWQ"):
# Move module and inputs to correct device
common_device = next(self.modules[i].parameters()).device
if common_device is None or str(common_device) == "cpu":
if torch.cuda.is_available():
best_device = "cuda:" + str(i % torch.cuda.device_count())
else:
best_device = get_best_device()
self.modules[i] = self.modules[i].to(best_device)
common_device = next(self.modules[i].parameters()).device
if self.module_kwargs.get("position_ids") is not None:
self.module_kwargs["position_ids"] = self.module_kwargs[
"position_ids"
].to(common_device)
if self.module_kwargs.get("attention_mask") is not None:
self.module_kwargs["attention_mask"] = self.module_kwargs[
"attention_mask"
].to(common_device)
self.inps = self.inps.to(common_device)
# [STEP 1]: Get layer, extract linear modules, extract input features
named_linears = get_named_linears(self.modules[i])
# Filter out the linear layers we don't want to quantize
named_linears = exclude_layers_to_not_quantize(
named_linears, self.modules_to_not_convert
)
input_feat = self._get_input_feat(self.modules[i], named_linears)
clear_memory()
# [STEP 2]: Compute and apply scale list
module_config: List[Dict] = self.awq_model.get_layers_for_scaling(
self.modules[i], input_feat, self.module_kwargs
)
scales_list = [
self._search_best_scale(self.modules[i], **layer)
for layer in module_config
]
apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
scales_list = append_str_prefix(
scales_list, get_op_name(self.model, self.modules[i]) + "."
)
# [STEP 3]: Compute and apply clipping list
if self.apply_clip:
clip_list = self._search_best_clip(
self.modules[i], named_linears, input_feat
)
apply_clip(self.modules[i], clip_list)
clip_list = append_str_prefix(
clip_list, get_op_name(self.model, self.modules[i]) + "."
)
# [STEP 4]: Quantize weights
if not self.export_compatible:
self._apply_quant(self.modules[i], named_linears)
clear_memory()
def pack(self):
for i in tqdm(range(len(self.modules)), desc="Packing"):
named_linears = get_named_linears(self.modules[i])
named_linears = exclude_layers_to_not_quantize(
named_linears, self.modules_to_not_convert
)
self._apply_quant(self.modules[i], named_linears)
clear_memory()
def _apply_quant(self, module, named_linears: Dict[str, nn.Linear]):
for name, linear_layer in named_linears.items():
# NOTE: small regression in perplexity if linear layer uses .cpu().float()
linear_layer = linear_layer.to(get_best_device()).half()
linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
linear_layer.weight.data
)
if self.version == "gemm":
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
q_linear_module = WQLinear_GEMM
elif self.version == "gemv":
q_linear_module = WQLinear_GEMV
elif self.version == "marlin":
q_linear_module = WQLinear_Marlin
elif self.version == "gemv_fast":
q_linear_module = WQLinear_GEMVFast
else:
raise ValueError(f"Unknown version {self.version}")
q_linear = q_linear_module.from_linear(
linear=linear_layer,
w_bit=self.w_bit,
group_size=self.group_size,
init_only=False,
scales=scales,
zeros=zeros,
)
linear_layer.cpu()
q_linear.to(next(module.parameters()).device)
set_op_by_name(module, name, q_linear)
clear_memory()
@torch.no_grad()
def _search_best_scale(
self,
module,
prev_op,
layers: List[nn.Linear],
inp: torch.Tensor,
module2inspect=None,
kwargs={},
):
if module2inspect is None:
assert len(layers) == 1
module2inspect = layers[0]
if "use_cache" in kwargs:
kwargs.pop("use_cache")
# Put x on the right device
inp = inp.to(next(module2inspect.parameters()).device)
# [STEP 1]: Compute per-channel mean of normalised weights
# All layer weights are concatenated together
weight = torch.cat([_m.weight for _m in layers], dim=0)
org_shape = weight.shape
# The weights are reshaped to be organised by quantization group
weight = weight.view(-1, self.group_size)
# Calculates the relative magnitude of the weights within each of the quantization groups,
# and rescales each group individually so that each group has weights on a 0-1 scale.
w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
# Resizes the rescaled weight matrix back up to its original dimensions
w_scale = w_scale.view(org_shape)
# Gets the average rescaled magnitude for each output channel
w_mean = w_scale.mean(0)
clear_memory(weight)
# [STEP 2]: Compute per-channel mean of the input activation
x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)
# [STEP 3]: Compute output of module
with torch.no_grad():
module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
fp16_output = module2inspect(inp, **module_kwargs)
if isinstance(fp16_output, tuple):
fp16_output = fp16_output[0]
# [STEP 4]: Compute loss
best_scales = self._compute_best_scale(
inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
)
return (
get_op_name(module, prev_op),
tuple([get_op_name(module, m) for m in layers]),
best_scales,
)
def _compute_best_scale(
self,
x,
w_mean,
x_mean,
module2inspect,
linears2scale: List[nn.Linear],
fp16_output,
kwargs={},
):
"""
Compute loss and select best scales
L(s) = || Q(W * s) (s^-1 * X) - W * X ||
Q: weight quantization function | pseudo_quantize_tensor(W * s)
X: inputs from calib dataset | X
W: original weights in FP16 | layer
s: per channel scaling factor | s^-1 * X
"""
n_grid = 20
history = []
best_ratio = -1
best_scales = None
best_error = float("inf")
org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
device = x.device
x_mean = x_mean.view(-1).to(device)
w_mean = w_mean.view(-1).to(device)
for ratio in range(n_grid):
# create new scales
ratio = ratio / n_grid
# NOTE: s^-1 * x is fused here, according to paper
if self.duo_scaling:
scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
else:
scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
scales = scales / (scales.max() * scales.min()).sqrt()
scales_view = scales.view(1, -1).to(device)
# Q(W * s)
for fc in linears2scale:
fc.weight.mul_(scales_view)
fc.weight.data = (
self.pseudo_quantize_tensor(fc.weight.data)[0] / scales_view
)
# W * X
int_w_output = module2inspect(x, **kwargs)
if isinstance(int_w_output, tuple):
int_w_output = int_w_output[0]
# compute mean squared error between FP16 and quantized outputs
loss = (
(fp16_output - int_w_output).float().pow(2).mean().item()
) # NOTE: float prevents overflow
history.append(loss)
if loss < best_error:
best_error = loss
best_ratio = ratio
best_scales = scales.clone()
module2inspect.load_state_dict(org_sd)
if best_ratio == -1:
logging.debug(history)
raise Exception
assert torch.isnan(best_scales).sum() == 0, best_scales
return best_scales.detach().cpu()
@torch.no_grad()
def _search_best_clip(self, layer, named_linears, input_feat):
clip_list = []
avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
for name in named_linears:
# due to qk bmm, it is hard to clip precisely
if any([_ in name for _ in avoid_clipping]):
continue
named_linears[name].to(get_best_device())
max_val = self._compute_best_clip(
named_linears[name].weight, input_feat[name]
)
clip_list.append((name, max_val))
named_linears[name].cpu()
return clip_list
@torch.no_grad()
def _compute_best_clip(
self,
w: torch.Tensor,
input_feat: torch.Tensor,
n_grid=20,
max_shrink=0.5,
n_sample_token=512,
):
assert w.dim() == 2
org_w_shape = w.shape
# w [co, ci] -> [co, 1, n_group, group size]
# input_feat [n_token, ci] -> [1, n_token, n_group, group size]
group_size = self.group_size if self.group_size > 0 else org_w_shape[1]
input_feat = input_feat.view(-1, input_feat.shape[-1])
input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
input_feat = input_feat[:, 0 :: input_feat.shape[1] // n_sample_token]
w = w.reshape(org_w_shape[0], 1, -1, group_size)
oc_batch_size = 256 if org_w_shape[0] % 256 == 0 else 64 # prevent OOM
assert org_w_shape[0] % oc_batch_size == 0
w_all = w
best_max_val_all = []
for i_b in range(org_w_shape[0] // oc_batch_size):
w = w_all[i_b * oc_batch_size : (i_b + 1) * oc_batch_size]
org_max_val = w.abs().amax(dim=-1, keepdim=True) # co, 1, n_group, 1
best_max_val = org_max_val.clone()
min_errs = torch.ones_like(org_max_val) * 1e9
input_feat = input_feat.to(w.device)
org_out = (input_feat * w).sum(dim=-1) # co, n_token, n_group
for i_s in range(int(max_shrink * n_grid)):
max_val = org_max_val * (1 - i_s / n_grid)
min_val = -max_val
cur_w = torch.clamp(w, min_val, max_val)
q_w = self.pseudo_quantize_tensor(cur_w)[0]
cur_out = (input_feat * q_w).sum(dim=-1)
# co, 1, n_group, 1
err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
del cur_w
del cur_out
cur_best_idx = err < min_errs
min_errs[cur_best_idx] = err[cur_best_idx]
best_max_val[cur_best_idx] = max_val[cur_best_idx]
best_max_val_all.append(best_max_val)
best_max_val = torch.cat(best_max_val_all, dim=0)
clear_memory(input_feat)
clear_memory(org_out)
return best_max_val.squeeze(1)
def init_quant(self, n_samples=128, seqlen=512):
modules = self.awq_model.get_model_layers(self.model)
samples = get_calib_dataset(
data=self.calib_data,
tokenizer=self.tokenizer,
n_samples=n_samples,
block_size=seqlen,
split=self.split,
text_column=self.text_column,
)
samples = torch.cat(samples, dim=0)
inps = []
layer_kwargs = {}
best_device = get_best_device()
modules[0] = modules[0].to(best_device)
self.awq_model.move_embed(self.model, best_device)
# get input and kwargs to layer 0
# with_kwargs is only supported in PyTorch 2.0
# use this Catcher hack for now
class Catcher(nn.Module):
def __init__(self, module):
super().__init__()
self.module = module
def forward(self, *args, **kwargs):
# assume first input to forward is hidden states
if len(args) > 0:
hidden_states = args[0]
del args
else:
first_key = list(kwargs.keys())[0]
hidden_states = kwargs.pop(first_key)
inps.append(hidden_states)
layer_kwargs.update(kwargs)
raise ValueError # early exit to break later inference
# patch layer 0 to catch input and kwargs
modules[0] = Catcher(modules[0])
try:
self.model(samples.to(next(self.model.parameters()).device))
except ValueError: # work with early exit
pass
modules[0] = modules[0].module # restore
# Update the layer kwargs via the `prepare_inputs_for_generation` method,
# which fills in the inputs the model expects and avoids unexpected errors.
layer_kwargs = self.model.prepare_inputs_for_generation(samples, **layer_kwargs)
# Pop the input_ids as they are not needed at all.
layer_kwargs.pop("input_ids")
del samples
inps = inps[0]
modules[0] = modules[0].cpu()
self.awq_model.move_embed(self.model, "cpu")
clear_memory()
if layer_kwargs.get("attention_mask") is not None:
layer_kwargs["attention_mask"] = layer_kwargs["attention_mask"].to(
best_device
)
return modules, layer_kwargs, inps
def _get_input_feat(self, layer, named_linears):
# firstly, get input features of all linear layers
def cache_input_hook(m, x, y, name, feat_dict):
x = x[0]
x = x.detach().cpu()
feat_dict[name].append(x)
input_feat = defaultdict(list)
handles = []
# FIXME: Workaround for Mixtral to use block_sparse_moe input features
if self.awq_model.model_type == "mixtral":
named_linears = {
**named_linears,
"block_sparse_moe": layer.block_sparse_moe,
}
for name in named_linears:
handles.append(
named_linears[name].register_forward_hook(
functools.partial(cache_input_hook, name=name, feat_dict=input_feat)
)
)
self.inps = self.inps.to(next(layer.parameters()).device) # in case multi-gpu
# get output as next layer's input
# Sanitize the kwargs in case the installed transformers version passes
# arguments that this module's forward does not accept.
# Useful for trust_remote_code models.
module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
self.inps = layer(self.inps, **module_kwargs)[0]
for h in handles:
h.remove()
# now solve for scaling and clipping
input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
return input_feat
def _sanitize_kwargs(self, inputs_kwargs, module):
"""
Remove the arguments that are not supported in the module's
forward pass to avoid breaking behaviour between different versions
of transformers.
Args:
inputs_kwargs (`dict`):
The input dictionary to pass to the model layer
module (`torch.nn.Module`):
Target module to quantize.
"""
module_signature = inspect.signature(module.forward).parameters
sanitized_kwargs = {}
for k, v in inputs_kwargs.items():
if k in module_signature:
sanitized_kwargs[k] = v
return sanitized_kwargs
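# Example (illustrative, hypothetical module): if layer.forward accepts
# (hidden_states, attention_mask, position_ids), then
#   self._sanitize_kwargs({"attention_mask": m, "cache_position": cp}, layer)
# returns {"attention_mask": m}; kwargs that the installed transformers version
# does not declare on this module's forward are silently dropped.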
import torch
import torch.nn as nn
from typing import Tuple, List
from awq.utils.utils import get_best_device
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation
allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
allowed_act_fns = [
nn.GELU,
BloomGelu,
NewGELUActivation,
PytorchGELUTanh,
GELUActivation,
]
@torch.no_grad()
def apply_clip(module, clip_list: List[Tuple[str, torch.Tensor]]):
for name, max_val in clip_list:
layer: nn.Linear = get_op_by_name(module, name)
layer.to(get_best_device())
max_val = max_val.to(layer.weight.device)
org_shape = layer.weight.shape
layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
layer.weight.data = layer.weight.data.reshape(org_shape)
layer.cpu()
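# Shape note (illustrative): max_val comes from _compute_best_clip with shape
# [out_features, n_group, 1]; the weight is viewed as
# [out_features, n_group, group_size], so torch.clamp broadcasts one threshold
# per (output channel, group) before the original shape is restored.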
def apply_scale(module, scales_list, input_feat_dict=None):
for prev_op_name, layer_names, scales in scales_list:
prev_op = get_op_by_name(module, prev_op_name)
layers = [get_op_by_name(module, name) for name in layer_names]
best_device = get_best_device()
prev_op.to(best_device)
for layer in layers:
layer.to(best_device)
scales = scales.to(best_device)
if (
isinstance(prev_op, nn.Linear)
and isinstance(layers, list)
and isinstance(layers[0], nn.Linear)
):
scale_fc_fcs(prev_op, layers, scales)
elif isinstance(prev_op, nn.Linear):
assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales)
elif (
any(isinstance(prev_op, t) for t in allowed_norms)
or "rmsnorm" in str(prev_op.__class__).lower()
):
scale_ln_fcs(prev_op, layers, scales)
elif any(isinstance(prev_op, t) for t in allowed_act_fns):
new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales)
else:
raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")
# apply the scaling to input feat if given; prepare it for clipping
if input_feat_dict is not None:
for layer_name in layer_names:
# Skip the modules that are not quantized
if layer_name in input_feat_dict:
inp = input_feat_dict[layer_name]
inp.div_(scales.view(1, -1).to(inp.device))
prev_op.cpu()
for layer in layers:
layer.cpu()
scales = scales.cpu()
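# Example entry in scales_list (illustrative, layer names hypothetical):
#   ("input_layernorm", ("self_attn.q_proj", "self_attn.k_proj", "self_attn.v_proj"), scales)
# The previous op is divided by `scales` and the following linears are multiplied
# by the same `scales`, so the block's function is preserved while the activations
# feeding the quantized linears are rebalanced.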
@torch.no_grad()
def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
if not isinstance(fcs, list):
fcs = [fcs]
scales = scales.to(ln.weight.device)
# GemmaRMSNorm is different from Llama's in that it multiplies
# (1 + weight) to the output, instead of just weight.
if isinstance(ln, GemmaRMSNorm):
ln.weight += 1
ln.weight.div_(scales)
ln.weight -= 1
else:
ln.weight.div_(scales)
if hasattr(ln, "bias") and ln.bias is not None:
ln.bias.div_(scales)
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
for p in ln.parameters():
assert torch.isnan(p).sum() == 0
for fc in fcs:
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
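# Invariant (illustrative): with g = ln.weight and a per-channel scale s,
#   (W * s) @ ((g / s) * x_hat) == W @ (g * x_hat)
# where W * s scales the columns of each following fc, so the block's output is
# unchanged while magnitude moves into the weights that will be quantized.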
@torch.no_grad()
def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear)
scales = scales.to(fc1.weight.device)
fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
fc2.weight.mul_(scales.view(1, -1))
for p in fc1.parameters():
assert torch.isnan(p).sum() == 0
for p in fc2.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
if not isinstance(fcs, list):
fcs = [fcs]
scales = scales.to(fc1.weight.device)
fc1.weight[-scales.size(0) :].div_(scales.view(-1, 1))
if fc1.bias is not None:
fc1.bias.div_(scales.view(-1))
for fc in fcs:
fc.weight.mul_(scales.view(1, -1))
for p in fc1.parameters():
assert torch.isnan(p).sum() == 0
for fc in fcs:
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu: nn.Module, fc: nn.Linear, scales: torch.Tensor):
assert any(isinstance(gelu, t) for t in allowed_act_fns)
assert isinstance(fc, nn.Linear)
fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
for p in fc.parameters():
assert torch.isnan(p).sum() == 0
import torch
import logging
from typing import List, Union
from datasets import load_dataset
def get_calib_dataset(
data: Union[str, List[str], List[List[int]]] = "pileval",
tokenizer=None,
n_samples=512,
block_size=512,
split="train",
text_column="text",
):
if isinstance(data, str):
if data == "pileval":
dataset = load_dataset("mit-han-lab/pile-val-backup", split="validation")
else:
dataset = load_dataset(data, split=split)
dataset = dataset.shuffle(seed=42)
elif isinstance(data, list):
if isinstance(data[0], str):
dataset = [{text_column: text} for text in data]
elif isinstance(data[0][0], int):
dataset = data
else:
raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words."
)
else:
raise NotImplementedError(
"Either pass a string to a huggingface dataset or a list"
"that is preprocessed with one sample of text per element"
" or a list of list of int for tokenized words."
)
samples = []
n_run = 0
for data in dataset:
if isinstance(data, list):
line_encoded = data
else:
line = data[text_column]
line = line.strip()
line_encoded = tokenizer.encode(line)
if len(line_encoded) > 512:
continue
sample = torch.tensor([line_encoded])
if sample.numel() == 0:
continue
samples.append(sample)
n_run += 1
if n_run == n_samples:
break
# now concatenate all samples and split according to block size
cat_samples = torch.cat(samples, dim=1)
n_split = cat_samples.shape[1] // block_size
logging.debug(f" * Split into {n_split} blocks")
return [
cat_samples[:, i * block_size : (i + 1) * block_size] for i in range(n_split)
]
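# Usage sketch (illustrative; the model id is hypothetical):
#   from transformers import AutoTokenizer
#   tok = AutoTokenizer.from_pretrained("org/some-model")
#   blocks = get_calib_dataset(data="pileval", tokenizer=tok,
#                              n_samples=128, block_size=512)
#   # `blocks` is a list of [1, 512] token-id tensors, ready for torch.cat(dim=0)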
import torch
from awq.modules.linear import (
WQLinear_GEMM,
WQLinear_GEMV,
WQLinear_Marlin,
WQLinear_Exllama,
WQLinear_ExllamaV2,
WQLinear_GEMVFast,
)
def prepare_correct_devices(next_layer, hidden_states, mask):
hidden_states = hidden_states.to(next_layer.device)
if mask is not None:
mask = mask.to(next_layer.device)
return hidden_states, mask
def prepare_cache(blocks, seqlen: int) -> int:
for block in blocks:
start_pos = block.attn.start_pos
will_cache_be_exceeded = start_pos + seqlen > block.attn.max_seq_len
# Reset and avoid retaining state when processing context
if seqlen > 1 and (will_cache_be_exceeded or start_pos > 0):
block.attn.start_pos = block.attn.cache.roll_kv_n_steps(
start_pos, n=start_pos
)
# During decoding, roll old tokens out of the window in small steps to avoid a large performance hit
elif seqlen == 1 and will_cache_be_exceeded:
block.attn.start_pos = block.attn.cache.roll_kv_n_steps(start_pos, n=100)
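# Behaviour note (illustrative): when a new prompt (seqlen > 1) would overflow the
# window, or leftover state from a previous prompt exists, the cache is rolled by
# start_pos, which resets the position to 0; during single-token decoding the
# window instead slides forward 100 positions at a time, dropping the oldest tokens.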
def prepare_input_ids(input_ids: torch.Tensor, last_forward_num_tokens: int):
# NOTE: from transformers 4.35.0, input_ids includes full context during decoding
num_input_tokens = input_ids.shape[-1]
num_new_tokens = num_input_tokens
if num_input_tokens != 1:
num_new_tokens = num_input_tokens - last_forward_num_tokens
# after context is processed, slice to latest token
if num_new_tokens == 1:
input_ids = input_ids[:, -1:]
return input_ids, last_forward_num_tokens + num_new_tokens
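# Worked example (illustrative): a 10-token prompt arrives as input_ids of shape
# [1, 10]; num_new_tokens = 10 - 0 = 10, nothing is sliced, and 10 is returned as
# the new running total. On the next decode step transformers (>= 4.35) passes the
# full [1, 11] context: 11 - 10 = 1, so only input_ids[:, -1:] is kept.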
def prepare_attention_mask(seqlen, start_pos, device, type_as: torch.Tensor):
mask = None
if seqlen > 1:
mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
mask = torch.triu(mask, diagonal=start_pos + 1).type_as(type_as)
return mask
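# Example (illustrative): seqlen=3, start_pos=0 produces the causal mask
#   [[0, -inf, -inf],
#    [0,    0, -inf],
#    [0,    0,    0]]
# broadcast against [1, 1, 3, 3] attention scores; for single-token decoding
# (seqlen == 1) no mask is needed and None is returned.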
def fuse_qkv(module, q_proj, k_proj, v_proj):
bias = (
torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0)
if q_proj.bias is not None
else None
)
if isinstance(q_proj, WQLinear_GEMV):
q_linear = WQLinear_GEMV
elif isinstance(q_proj, WQLinear_GEMM):
q_linear = WQLinear_GEMM
elif isinstance(q_proj, WQLinear_Exllama):
q_linear = WQLinear_Exllama
elif isinstance(q_proj, WQLinear_ExllamaV2):
q_linear = WQLinear_ExllamaV2
elif isinstance(q_proj, WQLinear_Marlin):
q_linear = WQLinear_Marlin
elif isinstance(q_proj, WQLinear_GEMVFast):
    q_linear = WQLinear_GEMVFast
else:
    raise ValueError(f"Unsupported q_proj type for QKV fusion: {type(q_proj)}")
qkv_layer = q_linear(
q_proj.w_bit,
q_proj.group_size,
q_proj.in_features,
q_proj.out_features + k_proj.out_features + v_proj.out_features,
q_proj.bias is not None,
next(iter(module.state_dict().values())).device,
)
if isinstance(q_proj, WQLinear_GEMV):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=0
)
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=0
)
qkv_layer.split_k_iters = q_proj.split_k_iters
elif isinstance(q_proj, WQLinear_GEMM):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
)
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
)
elif isinstance(q_proj, WQLinear_Exllama):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
)
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
)
elif isinstance(q_proj, WQLinear_ExllamaV2):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
)
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
)
elif isinstance(q_proj, WQLinear_Marlin):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1
)
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
)
# workspace is created in post_init
elif isinstance(q_proj, WQLinear_GEMVFast):
qkv_layer.qweight = torch.cat(
[q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=0
)
qkv_layer.qzeros = torch.cat(
[q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1
).contiguous()
qkv_layer.scales = torch.cat(
[q_proj.scales, k_proj.scales, v_proj.scales], dim=1
).contiguous()
qkv_layer.split_k_iters = q_proj.split_k_iters
qkv_layer.bias = bias
for layer in [q_proj, k_proj, v_proj]:
del (layer.qweight, layer.qzeros, layer.scales)
return qkv_layer
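# Note (illustrative): the concatenation axis above follows each backend's packed
# layout; GEMM/Exllama/ExllamaV2/Marlin keep packed output channels on dim=1 of
# qweight, while GEMV/GEMVFast keep output channels on dim=0. The fused layer
# reuses q_proj's w_bit/group_size and exposes
# out_features = q_proj.out_features + k_proj.out_features + v_proj.out_features.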
def fuse_linears(linears, device, dim=1, operation=torch.cat):
total_out_features = sum([layer.out_features for layer in linears])
fused = WQLinear_GEMM(
linears[0].w_bit,
linears[0].group_size,
linears[0].in_features,
total_out_features,
bias=None,
dev=device,
)
fused.qweight = operation([layer.qweight for layer in linears], dim=dim)
fused.qzeros = operation([layer.qzeros for layer in linears], dim=dim)
fused.scales = operation([layer.scales for layer in linears], dim=dim)
for layer in linears:
del (layer.qweight, layer.qzeros, layer.scales, layer)
return fused
def get_attention_shapes(
attention_shapes, max_seq_len, cache_batch_size, n_heads, n_kv_heads, head_dim
):
if attention_shapes is not None:
attention_shapes = attention_shapes
elif n_kv_heads == 0:
attention_shapes = {
# following fastertransformer definition
"cache_v": (
cache_batch_size,
n_heads,
max_seq_len,
head_dim,
),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (
cache_batch_size,
n_heads,
head_dim // 8,
max_seq_len,
8,
),
"xqkv_view": (-1, n_heads, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0],
"xk_slice": lambda xqkv: xqkv[:, :, 1],
"xv_slice": lambda xqkv: xqkv[:, :, 2],
"xq_view": (n_heads, head_dim),
"xk_view": (n_heads, head_dim),
"xv_view": (n_heads, head_dim),
"xk_reshape": (n_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_heads, head_dim),
"single_xv_view": (n_heads, head_dim),
}
else:
attention_shapes = {
# following fastertransformer definition
"cache_v": (
cache_batch_size,
n_kv_heads,
max_seq_len,
head_dim,
),
# 8: pack 8 fp16 in FT, if fp32 then use 4
"cache_k": (
cache_batch_size,
n_kv_heads,
head_dim // 8,
max_seq_len,
8,
),
"xqkv_view": (n_heads + n_kv_heads * 2, head_dim),
"xq_slice": lambda xqkv: xqkv[:, :, 0:n_heads],
"xk_slice": lambda xqkv: xqkv[:, :, n_heads : (n_heads + n_kv_heads)],
"xv_slice": lambda xqkv: xqkv[:, :, -n_kv_heads:],
"xq_view": (n_heads, head_dim),
"xk_view": (n_kv_heads, head_dim),
"xv_view": (n_kv_heads, head_dim),
"xk_reshape": (n_kv_heads, head_dim // 8, 8),
"single_xq_view": (n_heads, head_dim),
"single_xk_view": (n_kv_heads, head_dim),
"single_xv_view": (n_kv_heads, head_dim),
}
return attention_shapes
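# Example (illustrative): n_heads=32, n_kv_heads=8, head_dim=128,
# cache_batch_size=1, max_seq_len=2048 gives
#   cache_v:   (1, 8, 2048, 128)
#   cache_k:   (1, 8, 16, 2048, 8)    # head_dim // 8 = 16 packed slots of 8 fp16
#   xqkv_view: (48, 128)              # n_heads + 2 * n_kv_heads fused heads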
import torch.nn as nn
def get_named_linears(module):
return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
def get_op_by_name(module, op_name):
# get the op by its name relative to the module
for name, m in module.named_modules():
if name == op_name:
return m
raise ValueError(f"Cannot find op {op_name} in module {module}")
def set_op_by_name(layer, name, new_module):
levels = name.split(".")
if len(levels) > 1:
mod_ = layer
for l_idx in range(len(levels) - 1):
if levels[l_idx].isdigit():
mod_ = mod_[int(levels[l_idx])]
else:
mod_ = getattr(mod_, levels[l_idx])
setattr(mod_, levels[-1], new_module)
else:
setattr(layer, name, new_module)
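# Example (illustrative, hypothetical path): set_op_by_name(model,
# "layers.0.mlp.gate_proj", new_linear) walks model.layers[0].mlp and replaces
# its "gate_proj" attribute; purely numeric path segments index into
# ModuleList-style containers.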
def get_op_name(module, op):
# get the name of the op relative to the module
for name, m in module.named_modules():
if m is op:
return name
raise ValueError(f"Cannot find op {op} in module {module}")
def append_str_prefix(x, prefix):
if isinstance(x, str):
return prefix + x
elif isinstance(x, tuple):
return tuple([append_str_prefix(y, prefix) for y in x])
elif isinstance(x, list):
return [append_str_prefix(y, prefix) for y in x]
else:
return x
def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
if modules_to_not_convert is None:
return linear_layers
filtered_layers = {}
for name, linear_layer in linear_layers.items():
if not any(key in name for key in modules_to_not_convert):
filtered_layers[name] = linear_layer
return filtered_layers
import torch
AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
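# Editor's note (illustrative): with bits=4, eight 4-bit values are packed into
# each int32 in the interleaved slot order AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7];
# AWQ_REVERSE_ORDER is its inverse permutation and is applied after bit-unpacking
# (see reverse_awq_order below) to restore the natural 0..7 column order.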
def unpack_awq(qweight: torch.Tensor, qzeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=qzeros.device)
# unpacking columnwise
iweights = torch.bitwise_right_shift(qweight[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
iweights = iweights.view(iweights.shape[0], -1)
# unpacking columnwise
izeros = torch.bitwise_right_shift(qzeros[:, :, None], shifts[None, None, :]).to(
torch.int8 # smallest dtype available
)
izeros = izeros.view(izeros.shape[0], -1)
return iweights, izeros
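# Shape note (illustrative, bits=4): an int32 qweight of shape [K, N // 8]
# unpacks to int8 iweights of shape [K, N]; qzeros of shape
# [K // group_size, N // 8] unpacks to izeros of shape [K // group_size, N].
# Values still need masking with (2**bits - 1), as done by the callers below.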
def reverse_awq_order(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
reverse_order_tensor = torch.arange(
izeros.shape[-1],
dtype=torch.int32,
device=izeros.device,
)
reverse_order_tensor = reverse_order_tensor.view(-1, 32 // bits)
reverse_order_tensor = reverse_order_tensor[:, AWQ_REVERSE_ORDER]
reverse_order_tensor = reverse_order_tensor.view(-1)
izeros = izeros[:, reverse_order_tensor]
iweights = iweights[:, reverse_order_tensor]
return iweights, izeros
def pack_exllama(iweights: torch.Tensor, izeros: torch.Tensor, bits: int):
shifts = torch.arange(0, 32, bits, device=iweights.device)
# packing rowwise
iweights = iweights.view(iweights.shape[0] // (32 // bits), 32 // bits, -1)
qweight = (
torch.bitwise_left_shift(iweights, shifts[None, :, None])
.sum(dim=1)
.to(torch.int32)
)
# packing columnwise
izeros = izeros.view(-1, izeros.shape[1] // (32 // bits), 32 // bits)
qzeros = (
torch.bitwise_left_shift(izeros, shifts[None, None, :])
.sum(dim=-1)
.to(torch.int32)
)
return qweight, qzeros
def unpack_reorder_pack(qweight, qzeros, bits):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# Subtract 1 from the izeros tensor (exllama adds 1 during inference)
# We can remove it if we remove the +1 in the exllama code
izeros = izeros - 1
# Pack the qweight and qzeros tensors
qweight, qzeros = pack_exllama(iweight, izeros, bits)
return qweight, qzeros
def dequantize_gemm(qweight, qzeros, scales, bits, group_size):
# Unpack the qweight and qzeros tensors
iweight, izeros = unpack_awq(qweight, qzeros, bits)
# Reverse the order of the iweight and izeros tensors
iweight, izeros = reverse_awq_order(iweight, izeros, bits)
# overflow checks
iweight = torch.bitwise_and(iweight, (2**bits) - 1)
izeros = torch.bitwise_and(izeros, (2**bits) - 1)
# fp16 weights
scales = scales.repeat_interleave(group_size, dim=0)
izeros = izeros.repeat_interleave(group_size, dim=0)
iweight = (iweight - izeros) * scales
return iweight
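# Worked example (illustrative): a single 4-bit value q = 9 with zero-point z = 8
# and scale s = 0.05 dequantizes to (q - z) * s = 0.05. The repeat_interleave
# calls above expand the per-group scales/zeros to the full [K, N] grid so this
# formula applies elementwise.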