import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
try:
import habana_frameworks.torch.core as htcore
convert_from_uint4 = torch.ops.hpu.convert_from_uint4
except Exception as e:
hpu_import_exception = e
def error_raiser_hpu(*args, **kwargs):
raise ValueError(
f"Trying to use HPU, but could not import the HPU framework with the following error: {hpu_import_exception}"
)
convert_from_uint4 = error_raiser_hpu
logger = getLogger(__name__)
def pack_tensor(input, bits = 4):
normal = input.to(torch.int32)
q = torch.zeros((normal.shape[0], normal.shape[1] // 32 * bits), dtype=torch.int32)
i = 0
col = 0
while col < q.shape[1]:
for j in range(i, i + (32 // bits)):
q[:, col] |= normal[:, j] << (bits * (j - i))
i += 32 // bits
col += 1
q = q.to(torch.int32)
return q
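# A minimal sketch of what `pack_tensor` produces (illustrative only; 4-bit values, CPU
# tensors, and a 32-column input are assumptions made for the example):
#
#   vals = torch.arange(32, dtype=torch.int32).reshape(1, 32) % 16
#   packed = pack_tensor(vals, bits=4)      # shape (1, 32 // 32 * 4) == (1, 4)
#   hex(packed[0, 0].item())                # '0x76543210': eight 4-bit values, packed LSB-first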
class QuantLinear(nn.Module):
QUANT_TYPE = "hpu"
def __init__(
self,
bits,
group_size,
infeatures,
outfeatures,
bias,
use_cuda_fp16=True,
kernel_switch_threshold=128,
trainable=False,
weight_dtype=torch.float16,
):
logger.debug(f"qlinear_hpu QuantLinear::__init__ {bits=}, {group_size=}, {infeatures=}, {outfeatures=}, {bias=}, {use_cuda_fp16=}, {kernel_switch_threshold=}, {trainable=}, {weight_dtype=}")
super().__init__()
if bits != 4:
raise NotImplementedError("Only 4 bits are supported.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2**self.bits - 1
self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=weight_dtype,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=weight_dtype))
else:
self.bias = None
self.half_indim = self.infeatures // 2
self.wf = torch.tensor(list(range(0, 32, self.bits)), dtype=torch.int32).unsqueeze(0)
def _preprocessing(self):
self.qweight = self.qweight.cpu()
weight = self.unpack_weight_from_cuda_old_format()
new_qweight = pack_tensor(weight)
self.qweight = new_qweight.to('hpu')
# TODO: Support group indexing and remove the check
columns = self.qweight.shape[0]
g_idx_trivial = [i // self.group_size for i in range(columns)]
g_idx_trivial = torch.tensor(g_idx_trivial, dtype=torch.int32)
assert torch.equal(self.g_idx, g_idx_trivial), "Non-trivial tensor g_idx is not supported"
zeros = self.unpack_zeros_from_cuda_old_format().cpu()
new_qzeros = pack_tensor(zeros)
self.qzeros = new_qzeros.to('hpu')
def post_init(self):
self._preprocessing()
def pack(self, linear, scales, zeros, g_idx):
#TODO: implement
raise NotImplementedError("QuantLinear HPU currently doesn't support packing")
def set_packed(self, qlinear_cls):
self.qweight = qlinear_cls.qweight
self.qzeros = qlinear_cls.qzeros
self.scales = qlinear_cls.scales
self.bias = qlinear_cls.bias
def forward(self, x):
x_dtype = x.dtype
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape(-1, x.shape[-1])
scales = self.scales
qweight = self.qweight
zeros = self.qzeros
weight = convert_from_uint4(qweight, scales, zeros, x_dtype)
out = torch.matmul(x, weight)
out = out.reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
def unpack_zeros_from_cuda_old_format(self):
zeros = torch.bitwise_right_shift(
torch.unsqueeze(self.qzeros, 2).expand(-1, -1, 32 // self.bits),
self.wf.unsqueeze(0),
).to(torch.int16 if self.bits == 8 else torch.int8)
zeros = zeros + 1
zeros = torch.bitwise_and(
zeros, (2**self.bits) - 1
).to(self.scales.dtype) # NOTE: It appears that casting here after the `zeros = zeros + 1` is important.
zeros = zeros.reshape(-1, zeros.shape[1] * zeros.shape[2])
return zeros
def unpack_weight_from_cuda_old_format(self):
weight = torch.bitwise_right_shift(
torch.unsqueeze(self.qweight, 1).expand(-1, 32 // self.bits, -1),
self.wf.unsqueeze(-1),
).to(torch.int16 if self.bits == 8 else torch.int8)
weight = torch.bitwise_and(weight, (2**self.bits) - 1)
weight = weight.reshape((weight.shape[0]*weight.shape[1], weight.shape[2]))
return weight
__all__ = ["QuantLinear"]
# Copyright (C) Marlin.2024 Elias Frantar (elias.frantar@ist.ac.at)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
logger = getLogger(__name__)
try:
import autogptq_marlin_cuda
except ImportError as e:
marlin_import_exception = e
def error_raiser_marlin(*args, **kwargs):
raise ValueError(
f"Trying to use the marlin backend, but could not import the C++/CUDA dependencies with the following error: {marlin_import_exception}"
)
autogptq_marlin_cuda = error_raiser_marlin
def mul(A, B, C, s, workspace, thread_k=-1, thread_n=-1, sms=-1, max_par=16):
"""Marlin FP16xINT4 multiply; can be used within `torch.compile`.
@A: `torch.half` input matrix of shape `(m, k)` in standard row-major layout
@B: `torch.int` weight matrix of original shape `(k, n)` in Marlin format; see `Layer.pack()`
@C: `torch.half` out matrix of shape `(m, n)` in standard row-major layout
@s: `torch.half` scales of shape `(k / group_size, n)`
@workspace: `torch.int` tensor with at least `n / 128 * max_par` entries that are all zero
@thread_k: `k` size of a thread_tile in `B` (can usually be left as auto -1)
@thread_n: `n` size of a thread_tile in `B` (can usually be left as auto -1)
@sms: number of SMs to use for the kernel (can usually be left as auto -1)
@max_par: maximum number of batch 64 problems to solve in parallel for large input sizes
"""
autogptq_marlin_cuda.mul(A, B, C, s, workspace, thread_k, thread_n, sms, max_par)
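# Shape sketch for `mul` (illustrative; assumes the `autogptq_marlin_cuda` extension, a CUDA
# device of compute capability >= 8.0, group_size 128, and a hypothetical packed `layer`):
#
#   m, k, n, group_size = 16, 4096, 4096, 128
#   A = torch.randn(m, k, dtype=torch.half, device="cuda")
#   C = torch.empty(m, n, dtype=torch.half, device="cuda")
#   workspace = torch.zeros(n // 128 * 16, dtype=torch.int, device="cuda")
#   mul(A, layer.B, C, layer.s, workspace)   # layer.B: (k // 16, n * 16 // 8), layer.s: (k // group_size, n)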
# Precompute permutations for Marlin weight and scale shuffling
def _get_perms():
perm = []
for i in range(32):
perm1 = []
col = i // 4
for block in [0, 1]:
for row in [
2 * (i % 4),
2 * (i % 4) + 1,
2 * (i % 4 + 4),
2 * (i % 4 + 4) + 1,
]:
perm1.append(16 * row + col + 8 * block)
for j in range(4):
perm.extend([p + 256 * j for p in perm1])
perm = np.array(perm)
interleave = np.array([0, 2, 4, 6, 1, 3, 5, 7])
perm = perm.reshape((-1, 8))[:, interleave].ravel()
perm = torch.from_numpy(perm)
scale_perm = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single = []
for i in range(4):
scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return perm, scale_perm, scale_perm_single
_perm, _scale_perm, _scale_perm_single = _get_perms()
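# For reference, the precomputed permutations have fixed sizes: `_perm` holds 32 * 32 = 1024
# indices (applied per 1024-element chunk via `res.reshape((-1, _perm.numel()))[:, _perm]` in
# `pack`), `_scale_perm` has 64 entries, and `_scale_perm_single` has 32.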
class QuantLinear(nn.Module):
QUANT_TYPE = "marlin"
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if torch.version.hip:
raise ValueError("Can not use Marlin int4*fp16 kernel with AMD ROCm version of PyTorch as the kernel is not compatible. Please do not use `use_marlin=True` when using ROCm devices.")
if not torch.cuda.get_device_capability()[0] >= 8:
raise ValueError(f'Can not use Marlin int4*fp16 kernel with a device of compute capability {torch.cuda.get_device_capability()}, the minimum compute capability is 8.0 for Marlin kernel. Please do not use `use_marlin=True`, or please upgrade your GPU ("The more you buy, the more you save." - Taiwanese proverb).')
if infeatures % 128 != 0 or outfeatures % 256 != 0:
raise ValueError("`infeatures` must be divisible by 128 and `outfeatures` by 256.")
if bits not in [4]:
raise NotImplementedError("Only 4 bits are supported.")
if group_size not in [-1, 128] and group_size != infeatures:
raise ValueError("Only group_size -1 and 128 are supported.")
if infeatures % group_size != 0:
raise ValueError("`infeatures` must be divisible by `group_size`.")
if trainable:
raise NotImplementedError("Marlin does not support train.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.group_size = group_size if group_size != -1 else infeatures
self.register_buffer(
"B",
torch.empty((self.infeatures // 16, self.outfeatures * 16 // 8), dtype=torch.int),
)
self.register_buffer(
"s",
torch.empty((self.infeatures // group_size, self.outfeatures), dtype=torch.half),
)
# 128 is currently the minimum `tile_n`, hence it gives the maximum workspace size; 16 is the default `max_par`
self.register_buffer(
"workspace",
torch.zeros(self.outfeatures // 128 * 16, dtype=torch.int),
persistent=False,
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.half))
else:
self.bias = None
def post_init(self):
pass
def pack(self, linear, scales):
"""Pack a fake-quantized linear layer into this actual Marlin representation.
@linear: fake-quantized `torch.nn.Linear` layer to convert (must be of type `torch.half`)
@scales: corresponding quantization scales of shape `(outfeatures, groups)`
"""
if linear.weight.dtype != torch.half:
raise ValueError("Only `torch.half` weights are supported.")
tile = 16
maxq = 2**4 - 1
s = scales.t()
w = linear.weight.data.t()
if self.group_size != self.infeatures:
w = w.reshape((-1, self.group_size, self.outfeatures))
w = w.permute(1, 0, 2)
w = w.reshape((self.group_size, -1))
s = s.reshape((1, -1))
w = torch.round(w / s).int()
w += (maxq + 1) // 2
w = torch.clamp(w, 0, maxq)
if self.group_size != self.infeatures:
w = w.reshape((self.group_size, -1, self.outfeatures))
w = w.permute(1, 0, 2)
w = w.reshape((self.infeatures, self.outfeatures)).contiguous()
s = s.reshape((-1, len(_scale_perm)))[:, _scale_perm]
else:
s = s.reshape((-1, len(_scale_perm_single)))[:, _scale_perm_single]
s = s.reshape((-1, self.outfeatures)).contiguous()
w = w.reshape((self.infeatures // tile, tile, self.outfeatures // tile, tile))
w = w.permute((0, 2, 1, 3))
w = w.reshape((self.infeatures // tile, self.outfeatures * tile))
res = w
res = res.reshape((-1, _perm.numel()))[:, _perm].reshape(res.shape)
q = np.zeros((res.shape[0], res.shape[1] // 8), dtype=np.uint32)
res = res.cpu().numpy().astype(np.uint32)
for i in range(8):
q |= res[:, i::8] << 4 * i
q = torch.from_numpy(q.astype(np.int32)).to(w.device)
self.B[:, :] = q.to(self.B.device)
self.s[:, :] = s.to(self.s.device)
if linear.bias is not None:
if self.bias is not None:
self.bias[:] = linear.bias.data.to(self.bias.device)
else:
self.bias = linear.bias.clone()
def forward(self, A):
A = A.half()
C = torch.empty(A.shape[:-1] + (self.s.shape[1],), dtype=A.dtype, device=A.device)
mul(
A.view((-1, A.shape[-1])),
self.B,
C.view((-1, C.shape[-1])),
self.s,
self.workspace,
)
C = C + self.bias if self.bias is not None else C
return C
# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def unpack_4bit_to_32bit_signed(qweight, qzeros):
# Unpack 4-bit values and interpret them as signed integers
unpacked_weights = torch.zeros(
(qweight.shape[0] * 8, qweight.shape[1]),
dtype=torch.int8,
device=qweight.device,
requires_grad=False,
)
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * 8),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for row in range(unpacked_weights.shape[0]):
i = row % 8
unpacked_weights[row, :] = (qweight[row // 8, :] >> (4 * i)) & 0xF
for col in range(unpacked_zeros.shape[1]):
i = col % 8
unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
return unpacked_weights, unpacked_zeros + 1
def unpack_qzeros(qzeros):
unpacked_zeros = torch.zeros(
(qzeros.shape[0], qzeros.shape[1] * 8),
dtype=torch.int8,
device=qzeros.device,
requires_grad=False,
)
for col in range(unpacked_zeros.shape[1]):
i = col % 8
unpacked_zeros[:, col] = (qzeros[:, col // 8] >> (4 * i)) & 0xF
return unpacked_zeros + 1
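# A worked example of the nibble layout assumed above (illustrative only):
#
#   packed = torch.tensor([[0x76543210]], dtype=torch.int32)
#   unpack_qzeros(packed).tolist()   # [[1, 2, 3, 4, 5, 6, 7, 8]] -- stored zeros 0..7, offset by +1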
# Copied from https://github.com/IST-DASLab/marlin/pull/1
@torch.no_grad()
def dequantize_weight(layer):
qweight, qzeros, scales = layer.qweight, layer.qzeros, layer.scales
unpacked_qweight, unpacked_qzeros = unpack_4bit_to_32bit_signed(qweight, qzeros)
group_size = unpacked_qweight.shape[0] // scales.shape[0]
scales = scales.repeat_interleave(group_size, dim=0)
unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
unpacked_qweight = (unpacked_qweight - unpacked_qzeros) * scales
return unpacked_qweight.T, unpacked_qzeros
def dequantize_qzeros(layer):
qzeros = layer.qzeros
unpacked_qzeros = unpack_qzeros(qzeros)
group_size = layer.group_size
unpacked_qzeros = unpacked_qzeros.repeat_interleave(group_size, dim=0)
return unpacked_qzeros
__all__ = ["QuantLinear", "dequantize_weight"]
import math
from logging import getLogger
import numpy as np
import torch
from gekko import GEKKO
from torch import nn
logger = getLogger(__name__)
try:
import cQIGen as qinfer
except ImportError as e:
exception_qinfer = e
class FakeQInfer:
def __getattr__(self, name):
raise ImportError(f"cQIGen is not installed or not correctly installed. {exception_qinfer}")
qinfer = FakeQInfer()
def mem_model(N, M, T, mu, tu, bits, l1, p, gs):
m = GEKKO() # create GEKKO model
# cinfergen:
# if bits == 3:
#     tu = tu * 3
B = m.Const(value=bits)
TP = m.Const(value=T // p)
k = m.Var(1, integer=True, lb=1)
z = m.Var(1, integer=True, lb=1)
w = m.Var(1, integer=True, lb=1)
y = m.Var(1, integer=True, lb=1)
mb = m.Var(mu, integer=True, lb=1)
if gs != -1:
gg = m.Var(1, integer=True, lb=1)
tb = m.Var(tu, integer=True, lb=1, ub=int(T / p))
L = m.Var(integer=True, lb=0, ub=l1)
m.Equation(L == 32 * mb * N + B * mb * tb + 32 * tb * N)
m.Equation(mb * k == M)
if gs != -1:
m.Equation(gs * gg == mb)
# m.Equation(tb * z == T)
m.Equation(tb * z == TP)
m.Equation(mu * w == mb)
m.Equation(tu * y == tb)
# m.Equation(tb * v == tt)
m.Maximize(L)
m.options.SOLVER = 1
m.solver_options = [
"minlp_maximum_iterations 1000",  # maximum MINLP iterations
"minlp_max_iter_with_int_sol 10",  # MINLP iterations with an integer solution
"minlp_as_nlp 0",  # treat MINLP as NLP
"nlp_maximum_iterations 100",  # NLP sub-problem max iterations
"minlp_branch_method 2",  # 1 = depth first, 2 = breadth first
"minlp_integer_tol 0.00",  # maximum deviation from whole number
"minlp_gap_tol 0.01",  # convergence tolerance
]
try:
m.solve(disp=False)
except Exception:
try:
m.solver_options = [
"minlp_maximum_iterations 1000",  # maximum MINLP iterations
"minlp_max_iter_with_int_sol 10",  # MINLP iterations with an integer solution
"minlp_as_nlp 0",  # treat MINLP as NLP
"nlp_maximum_iterations 100",  # NLP sub-problem max iterations
"minlp_branch_method 1",  # 1 = depth first, 2 = breadth first
"minlp_integer_tol 0.00",  # maximum deviation from whole number
"minlp_gap_tol 0.01",  # convergence tolerance
]
m.solve(disp=False)
except Exception:
# mytb = T//p
mytb = tu
if gs != -1:
mymb = gs
while 32 * (mymb + gs) * N + bits * (mymb + gs) * mytb + 32 * mytb * N < l1:
mymb += gs
while M % mymb != 0:
mymb -= gs
return (int(mymb), int(mytb))
else:
mymb = mu
while 32 * (mymb + mu) * N + bits * (mymb + mu) * mytb + 32 * mytb * N < l1:
mymb += mu
while M % mymb != 0:
mymb -= mu
return (int(mymb), int(mytb))
return (int(mb.value[0]), int(tb.value[0]))
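# A sketch of how `mem_model` is used (the concrete numbers are assumptions, not tuned values):
#
#   mb, tb = mem_model(N=1, M=4096, T=4096, mu=16, tu=32, bits=4, l1=2**18, p=8, gs=128)
#   # mb divides M (and is a multiple of gs when gs != -1), tb divides T // p, and the working
#   # set 32*mb*N + bits*mb*tb + 32*tb*N is kept within the l1 budget.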
params = {}
def compute_reductions(x, gs=-1, cpp=True):
if cpp:
if len(x.shape) != 1:
rows, cols = x.shape
else:
rows = 1
cols = x.shape[0]
if gs == -1:
out = torch.zeros(rows).float().contiguous()
mygs = cols
else:
out = torch.zeros(rows, cols // gs).float().contiguous()
mygs = gs
qinfer.compute_reduction_cpp(x, out, rows, cols, mygs)
return out
if gs == -1:
if len(x.shape) != 1:
return torch.sum(x, 1)
else:
return torch.sum(x)
else:
if len(x.shape) != 1:
rows, cols = x.shape
out = torch.zeros(rows, cols // gs).float().contiguous()
for i in range(cols // gs):
out[:, i] = torch.sum(x[:, i * gs : (i + 1) * gs], 1)
return out
else:
cols = x.shape[0]
out = torch.zeros(cols // gs).float().contiguous()
for i in range(cols // gs):
out[i] = torch.sum(x[i * gs : (i + 1) * gs])
return out
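# A minimal illustration of the pure-torch path (cpp=False), which needs no cQIGen build:
#
#   x = torch.arange(8, dtype=torch.float32).reshape(2, 4)
#   compute_reductions(x, gs=2, cpp=False)   # tensor([[ 1.,  5.], [ 9., 13.]]) -- per-group row sums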
def process_zeros_scales(zeros, scales, bits, M):
if zeros.dtype != torch.float32:
new_zeros = torch.zeros_like(scales).float().contiguous()
if bits == 4:
qinfer.unpack_zeros4(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 2:
qinfer.unpack_zeros2(zeros, new_zeros, new_zeros.shape[0], new_zeros.shape[1])
elif bits == 3:
logger.info("Unpacking zeros for 3 bits")
new_scales = scales.contiguous()
else:
if scales.shape[1] != M:
new_scales = scales.transpose(0, 1).contiguous()
else:
new_scales = scales.contiguous()
if zeros.shape[1] != M:
new_zeros = zeros.transpose(0, 1).contiguous()
else:
new_zeros = zeros.contiguous()
return new_zeros, new_scales
class QuantLinear(nn.Module):
QUANT_TYPE = "qigen"
def __init__(
self,
bits,
group_size,
infeatures,
outfeatures,
bias=None,
trainable=False,
hint=1,
p=8,
l1=2**18,
):
super().__init__()
if bits not in [2, 4]:
raise NotImplementedError("Only 2,4 bits are supported.")
if trainable:
raise NotImplementedError("Qigen kernel does not support training.")
self.bits = bits
self.infeatures = infeatures
self.outfeatures = outfeatures
n = hint
m = self.infeatures
t = self.outfeatures
# registers for now are fixed
if bits == 3:
packed = 32
mu = 32
tu = 32
else:
packed = 32 // bits
mu = 16
tu = 32
global params
if (m, t) in params:
mb = params[(m, t)][0]
tb = params[(m, t)][1]
else:
mb, tb = mem_model(n, m, t, mu, tu, bits, l1, p, group_size)
params[(m, t)] = (mb, tb)
split = np.ones(p)
split = split * tb
while np.sum(split) < t:
split = split + tb
idx = p - 1
while np.sum(split) > t:
split[idx] = split[idx] - tb
idx = idx - 1
assert np.sum(split) == t
split = split.astype(int)
self.tt = int(split[0])
if split[0] == split[-1]:
self.cutoff = int(p + 1)
else:
self.cutoff = int(idx + 1)
self.mb = mb # // packed
self.tb = tb
self.group_size = group_size
self.register_buffer("bias", torch.zeros(self.outfeatures))
self.register_buffer(
"zeros",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=torch.float32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=torch.float32,
),
)
if bits == 4:
self.register_buffer(
"qweight",
torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous(),
)
elif bits == 3:
self.register_buffer(
"qweight",
torch.zeros(int(self.infeatures // packed * 3 * self.outfeatures)).int().contiguous(),
)
elif bits == 2:
self.register_buffer(
"qweight",
torch.zeros(int(self.infeatures // packed * self.outfeatures)).int().contiguous(),
)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
x = x.reshape((-1, x.shape[-1])).to(torch.float32)
B = x.shape[0]
new_x = x.T.contiguous()
out = torch.zeros((B, self.outfeatures), dtype=torch.float32)
sums = compute_reductions(x, gs=self.group_size, cpp=True).contiguous()
if self.group_size == -1:
if self.bits == 4:
qinfer.forward4(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.cutoff,
)
elif self.bits == 2:
qinfer.forward2(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.cutoff,
)
elif self.bits == 3:
qinfer.forward3(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.cutoff,
)
else:
if self.bits == 4:
qinfer.forward_gs4(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.group_size,
self.cutoff,
)
elif self.bits == 2:
qinfer.forward_gs2(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.group_size,
self.cutoff,
)
elif self.bits == 3:
qinfer.forward_gs3(
new_x,
self.qweight,
out,
self.bias,
self.scales,
self.zeros,
sums,
B,
self.infeatures,
self.outfeatures,
B,
self.mb,
self.tb,
self.tt,
self.group_size,
self.cutoff,
)
return out.reshape(out_shape)
import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
from ..triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
try:
from ..triton_utils.kernels import (
QuantLinearFunction,
QuantLinearInferenceOnlyFunction,
quant_matmul_248,
quant_matmul_inference_only_248,
transpose_quant_matmul_248,
)
except ImportError as e:
triton_import_exception = e
def error_raiser_triton(*args, **kwargs):
raise ValueError(
f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
)
class FakeTriton:
def __getattr__(self, name):
raise ImportError(
f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
)
quant_matmul_248 = error_raiser_triton
transpose_quant_matmul_248 = error_raiser_triton
quant_matmul_inference_only_248 = error_raiser_triton
QuantLinearFunction = FakeTriton
QuantLinearInferenceOnlyFunction = FakeTriton
class QuantLinear(nn.Module, TritonModuleMixin):
QUANT_TYPE = "triton"
def __init__(self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
if infeatures % 32 != 0 or outfeatures % 32 != 0:
raise NotImplementedError("in_feature and out_feature must be divisible by 32.")
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2**self.bits - 1
self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor([i // self.group_size for i in range(infeatures)], dtype=torch.int32),
)
if bias:
self.register_buffer("bias", torch.zeros((outfeatures), dtype=torch.float16))
else:
self.bias = None
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round((W[:, idx] + scale_zeros[self.g_idx[idx]]) / self.scales[self.g_idx[idx]]).to(torch.int)[
:, None
]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros((intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
quant_linear_fn = QuantLinearFunction if self.trainable else QuantLinearInferenceOnlyFunction
out = quant_linear_fn.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
"""
Pre-tunes the quantized kernel
"""
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, cls):
continue
k = m.infeatures
n = m.outfeatures
if (k, n) not in kn_values:
kn_values[(k, n)] = (
m.qweight,
m.scales,
m.qzeros,
m.g_idx,
m.bits,
m.maxq,
)
logger.info(f"Found {len(kn_values)} unique KN Linear values.")
logger.info("Warming up autotune cache ...")
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2**m
for (k, n), (
qweight,
scales,
qzeros,
g_idx,
bits,
maxq,
) in kn_values.items():
if transpose:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
a = torch.randn(m, n, dtype=torch.float16, device=model.device)
transpose_quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
else:
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_inference_only_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values
__all__ = ["QuantLinear"]
import math
from logging import getLogger
import numpy as np
import torch
import torch.nn as nn
import transformers
from ..triton_utils.mixin import TritonModuleMixin
logger = getLogger(__name__)
try:
from ..triton_utils.dequant import QuantLinearFunction, quant_matmul_248
except ImportError as e:
triton_import_exception = e
def error_raiser_triton(*args, **kwargs):
raise ValueError(
f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
)
class FakeTriton:
def __getattr__(self, name):
raise ImportError(
f"Trying to use the triton backend, but could not import triton dependencies with the following error: {triton_import_exception}"
)
quant_matmul_248 = error_raiser_triton
QuantLinearFunction = FakeTriton
QuantLinearInferenceOnlyFunction = FakeTriton
class QuantLinear(nn.Module, TritonModuleMixin):
"""
Triton v2 quantized linear layer.
Calls the dequant kernel (see triton_utils/dequant) to dequantize the weights, then uses
torch.matmul to compute the output, whereas the original `triton` quantized linear layer fuses
dequantization and matmul into a single kernel.
"""
QUANT_TYPE = "tritonv2"
def __init__(
self, bits, group_size, infeatures, outfeatures, bias, trainable=False, **kwargs
):
super().__init__()
if bits not in [2, 4, 8]:
raise NotImplementedError("Only 2,4,8 bits are supported.")
if infeatures % 32 != 0 or outfeatures % 32 != 0:
raise NotImplementedError(
"in_feature and out_feature must be divisible by 32."
)
self.infeatures = infeatures
self.outfeatures = outfeatures
self.bits = bits
self.group_size = group_size if group_size != -1 else infeatures
self.maxq = 2**self.bits - 1
self.register_buffer(
"qweight",
torch.zeros((infeatures // 32 * self.bits, outfeatures), dtype=torch.int32),
)
self.register_buffer(
"qzeros",
torch.zeros(
(
math.ceil(infeatures / self.group_size),
outfeatures // 32 * self.bits,
),
dtype=torch.int32,
),
)
self.register_buffer(
"scales",
torch.zeros(
(math.ceil(infeatures / self.group_size), outfeatures),
dtype=torch.float16,
),
)
self.register_buffer(
"g_idx",
torch.tensor(
[i // self.group_size for i in range(infeatures)], dtype=torch.int32
),
)
if bias:
self.register_buffer(
"bias", torch.zeros((outfeatures), dtype=torch.float16)
)
else:
self.bias = None
self.trainable = trainable
def post_init(self):
pass
def pack(self, linear, scales, zeros, g_idx=None):
W = linear.weight.data.clone()
if isinstance(linear, nn.Conv2d):
W = W.flatten(1)
if isinstance(linear, transformers.pytorch_utils.Conv1D):
W = W.t()
self.g_idx = g_idx.clone() if g_idx is not None else self.g_idx
scales = scales.t().contiguous()
zeros = zeros.t().contiguous()
scale_zeros = zeros * scales
self.scales = scales.clone().half()
if linear.bias is not None:
self.bias = linear.bias.clone().half()
intweight = []
for idx in range(self.infeatures):
intweight.append(
torch.round(
(W[:, idx] + scale_zeros[self.g_idx[idx]])
/ self.scales[self.g_idx[idx]]
).to(torch.int)[:, None]
)
intweight = torch.cat(intweight, dim=1)
intweight = intweight.t().contiguous()
intweight = intweight.numpy().astype(np.uint32)
i = 0
row = 0
qweight = np.zeros(
(intweight.shape[0] // 32 * self.bits, intweight.shape[1]), dtype=np.uint32
)
while row < qweight.shape[0]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qweight[row] |= intweight[j] << (self.bits * (j - i))
i += 32 // self.bits
row += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qweight = qweight.astype(np.int32)
self.qweight = torch.from_numpy(qweight)
zeros -= 1
zeros = zeros.numpy().astype(np.uint32)
qzeros = np.zeros(
(zeros.shape[0], zeros.shape[1] // 32 * self.bits), dtype=np.uint32
)
i = 0
col = 0
while col < qzeros.shape[1]:
if self.bits in [2, 4, 8]:
for j in range(i, i + (32 // self.bits)):
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
i += 32 // self.bits
col += 1
else:
raise NotImplementedError("Only 2,4,8 bits are supported.")
qzeros = qzeros.astype(np.int32)
self.qzeros = torch.from_numpy(qzeros)
def forward(self, x):
out_shape = x.shape[:-1] + (self.outfeatures,)
quant_linear_fn = QuantLinearFunction
out = quant_linear_fn.apply(
x.reshape(-1, x.shape[-1]),
self.qweight,
self.scales,
self.qzeros,
self.g_idx,
self.bits,
self.maxq,
)
out = out.half().reshape(out_shape)
out = out + self.bias if self.bias is not None else out
return out
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
"""
Pre-tunes the quantized kernel
"""
from tqdm import tqdm
kn_values = {}
for _, m in model.named_modules():
if not isinstance(m, cls):
continue
k = m.infeatures
n = m.outfeatures
if (k, n) not in kn_values:
kn_values[(k, n)] = (
m.qweight,
m.scales,
m.qzeros,
m.g_idx,
m.bits,
m.maxq,
)
logger.info(f"Found {len(kn_values)} unique KN Linear values.")
logger.info("Warming up autotune cache ...")
with torch.no_grad():
for m in tqdm(range(0, math.ceil(math.log2(seqlen)) + 1)):
m = 2**m
for (k, n), (
qweight,
scales,
qzeros,
g_idx,
bits,
maxq,
) in kn_values.items():
a = torch.randn(m, k, dtype=torch.float16, device=model.device)
quant_matmul_248(a, qweight, scales, qzeros, g_idx, bits, maxq)
del kn_values
__all__ = ["QuantLinear"]
import builtins
import math
import time
from typing import Dict
import triton
# Code based on https://github.com/fpgaminer/GPTQ-triton
"""
Mostly the same as the autotuner in Triton, but with a few changes like using 40 runs instead of 100.
"""
class CustomizedTritonAutoTuner(triton.KernelInterface):
def __init__(
self,
fn,
arg_names,
configs,
key,
reset_to_zero,
prune_configs_by: Dict = None,
nearest_power_of_two: bool = False,
):
if not configs:
self.configs = [triton.Config({}, num_warps=4, num_stages=2)]
else:
self.configs = configs
self.key_idx = [arg_names.index(k) for k in key]
self.nearest_power_of_two = nearest_power_of_two
self.cache = {}
# hook to reset all required tensor to zeros before relaunching a kernel
self.hook = lambda args: 0
if reset_to_zero is not None:
self.reset_idx = [arg_names.index(k) for k in reset_to_zero]
def _hook(args):
for i in self.reset_idx:
args[i].zero_()
self.hook = _hook
self.arg_names = arg_names
# prune configs
if prune_configs_by:
perf_model, top_k = (
prune_configs_by["perf_model"],
prune_configs_by["top_k"],
)
if "early_config_prune" in prune_configs_by:
early_config_prune = prune_configs_by["early_config_prune"]
else:
perf_model, top_k, early_config_prune = None, None, None
self.perf_model, self.configs_top_k = perf_model, top_k
self.early_config_prune = early_config_prune
self.fn = fn
def _bench(self, *args, config, **meta):
# check for conflicts, i.e. meta-parameters both provided
# as kwargs and by the autotuner
conflicts = meta.keys() & config.kwargs.keys()
if conflicts:
raise ValueError(
f"Conflicting meta-parameters: {', '.join(conflicts)}."
" Make sure that you don't re-define auto-tuned symbols."
)
# augment meta-parameters with tunable ones
current = dict(meta, **config.kwargs)
def kernel_call():
if config.pre_hook:
config.pre_hook(self.nargs)
self.hook(args)
self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**current,
)
try:
# In testing, using only 40 reps seems to be close enough, and it appears to be what PyTorch uses
# PyTorch also sets fast_flush to True, but I didn't see any speedup so I'll leave the default
return triton.testing.do_bench(kernel_call, quantiles=(0.5, 0.2, 0.8), rep=40)
except triton.OutOfResources:
return (float("inf"), float("inf"), float("inf"))
def run(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
if len(self.configs) > 1:
key = tuple(args[i] for i in self.key_idx)
# This reduces the amount of autotuning by rounding the keys to the nearest power of two
# In my testing this gives decent results, and greatly reduces the amount of tuning required
if self.nearest_power_of_two:
key = tuple([2 ** int(math.log2(x) + 0.5) for x in key])
if key not in self.cache:
# prune configs
pruned_configs = self.prune_configs(kwargs)
bench_start = time.time()
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
bench_end = time.time()
self.bench_time = bench_end - bench_start
self.cache[key] = builtins.min(timings, key=timings.get)
self.hook(args)
self.configs_timings = timings
config = self.cache[key]
else:
config = self.configs[0]
self.best_config = config
if config.pre_hook is not None:
config.pre_hook(self.nargs)
return self.fn.run(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
def prune_configs(self, kwargs):
pruned_configs = self.configs
if self.early_config_prune:
pruned_configs = self.early_config_prune(self.configs, self.nargs)
if self.perf_model:
top_k = self.configs_top_k
if isinstance(top_k, float) and top_k <= 1.0:
top_k = int(len(self.configs) * top_k)
if len(pruned_configs) > top_k:
est_timing = {
config: self.perf_model(
**self.nargs,
**kwargs,
**config.kwargs,
num_stages=config.num_stages,
num_warps=config.num_warps,
)
for config in pruned_configs
}
pruned_configs = sorted(est_timing.keys(), key=lambda x: est_timing[x])[:top_k]
return pruned_configs
def warmup(self, *args, **kwargs):
self.nargs = dict(zip(self.arg_names, args))
for config in self.prune_configs(kwargs):
self.fn.warmup(
*args,
num_warps=config.num_warps,
num_stages=config.num_stages,
**kwargs,
**config.kwargs,
)
self.nargs = None
def autotune(configs, key, prune_configs_by=None, reset_to_zero=None, nearest_power_of_two=False):
def decorator(fn):
return CustomizedTritonAutoTuner(
fn,
fn.arg_names,
configs,
key,
reset_to_zero,
prune_configs_by,
nearest_power_of_two,
)
return decorator
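# Usage mirrors `triton.autotune` (sketch; the config values below are placeholders, not tuned):
#
#   @autotune(
#       configs=[triton.Config({"BLOCK_SIZE_M": 64, "BLOCK_SIZE_N": 64}, num_stages=4, num_warps=4)],
#       key=["M", "N", "K"],
#       nearest_power_of_two=True,
#   )
#   @triton.jit
#   def my_kernel(...):
#       ...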
def matmul248_kernel_config_pruner(configs, nargs):
"""
The main purpose of this function is to shrink BLOCK_SIZE_* when the corresponding dimension is smaller.
"""
m = max(2 ** int(math.ceil(math.log2(nargs["M"]))), 16)
n = max(2 ** int(math.ceil(math.log2(nargs["N"]))), 16)
k = max(2 ** int(math.ceil(math.log2(nargs["K"]))), 16)
used = set()
for config in configs:
block_size_m = min(m, config.kwargs["BLOCK_SIZE_M"])
block_size_n = min(n, config.kwargs["BLOCK_SIZE_N"])
block_size_k = min(k, config.kwargs["BLOCK_SIZE_K"])
group_size_m = config.kwargs["GROUP_SIZE_M"]
if (
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
) in used:
continue
used.add(
(
block_size_m,
block_size_n,
block_size_k,
group_size_m,
config.num_stages,
config.num_warps,
)
)
yield triton.Config(
{
"BLOCK_SIZE_M": block_size_m,
"BLOCK_SIZE_N": block_size_n,
"BLOCK_SIZE_K": block_size_k,
"GROUP_SIZE_M": group_size_m,
},
num_stages=config.num_stages,
num_warps=config.num_warps,
)
__all__ = ["autotune"]
import itertools
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
def make_dequant_configs(block_sizes, num_warps):
configs = []
for bs, ws in itertools.product(block_sizes, num_warps):
configs.append(triton.Config({"X_BLOCK": bs}, num_warps=ws))
return configs
DEFAULT_DEQUANT_CONFIGS = make_dequant_configs([128, 256, 512, 1024], [4, 8])
@triton.autotune(DEFAULT_DEQUANT_CONFIGS, key=["numels"])
@triton.jit
def dequant_kernel_248(
g_idx_ptr,
scales_ptr,
qweight_ptr,
qzeros_ptr,
out_ptr,
numels,
maxq: tl.constexpr,
bits: tl.constexpr,
outfeatures: tl.constexpr,
num_groups: tl.constexpr,
X_BLOCK: tl.constexpr,
):
# Block indexing
xoffset = tl.program_id(0) * X_BLOCK
x_index = xoffset + tl.arange(0, X_BLOCK)
xmask = x_index < numels
row_idx = x_index // outfeatures
col_idx = x_index % outfeatures
elements_per_feature: tl.constexpr = 32 // bits
# Load parameters
g_idx = tl.load(g_idx_ptr + (row_idx), None, eviction_policy="evict_last")
qweights = tl.load(
qweight_ptr + (col_idx + (outfeatures * (row_idx // elements_per_feature))),
None,
)
wf_weights = (row_idx % elements_per_feature) * bits
wf_zeros = (col_idx % elements_per_feature) * bits
tmp1 = g_idx + num_groups
tmp2 = g_idx < 0
tl.device_assert(g_idx >= 0, "index out of bounds: 0 <= tmp0 < 0")
groups = tl.where(tmp2, tmp1, g_idx) # tmp3 are g_idx
scales = tl.load(scales_ptr + (col_idx + (outfeatures * groups)), None).to(
tl.float32
)
# Unpack weights
weights = qweights >> wf_weights # bit shift qweight
weights = weights & maxq
# Unpack zeros
qzero_ncols: tl.constexpr = outfeatures // elements_per_feature
qzeros = tl.load(
qzeros_ptr + ((qzero_ncols * groups) + (col_idx // elements_per_feature)),
None,
eviction_policy="evict_last",
)
zeros = qzeros >> wf_zeros
zeros = zeros & maxq
# Dequantize
zeros = zeros + 1
weights = weights - zeros
weights = weights.to(tl.float32)
weights = scales * weights
tl.store(out_ptr + (x_index), weights, mask=xmask)
def dequant248(qweight, scales, qzeros, g_idx, bits, maxq=None):
"""
Launcher for triton dequant kernel. Only valid for bits = 2, 4, 8
"""
num_groups = scales.shape[0]
outfeatures = scales.shape[1]
infeatures = g_idx.shape[0]
out = torch.empty((infeatures, outfeatures), device="cuda", dtype=torch.float16)
numels = out.numel()
maxq = 2**bits - 1 if maxq is None else maxq
grid = lambda meta: (triton.cdiv(numels, meta["X_BLOCK"]),) # noqa: E731
dequant_kernel_248[grid](
g_idx,
scales,
qweight,
qzeros,
out,
numels,
maxq=maxq,
bits=bits,
outfeatures=outfeatures,
num_groups=num_groups,
)
return out
def quant_matmul_248(
input, qweight, scales, qzeros, g_idx, bits, maxq=None, transpose=False
):
W = dequant248(qweight, scales, qzeros, g_idx, bits, maxq=maxq)
if transpose:
return input @ W.t()
return input @ W
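# Shape sketch (illustrative; assumes Triton, a CUDA device, and 4-bit GPTQ packing):
#
#   W = dequant248(qweight, scales, qzeros, g_idx, bits=4)            # (infeatures, outfeatures) fp16
#   y = quant_matmul_248(x, qweight, scales, qzeros, g_idx, bits=4)   # x: (m, infeatures) -> (m, outfeatures)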
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = quant_matmul_248(
grad_output, qweight, scales, qzeros, g_idx, bits, maxq, transpose=True
)
return grad_input, None, None, None, None, None, None
from logging import getLogger
import torch
import triton
import triton.language as tl
from torch.cuda.amp import custom_bwd, custom_fwd
from . import custom_autotune
logger = getLogger(__name__)
# Code based on https://github.com/fpgaminer/GPTQ-triton
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
prune_configs_by={
"early_config_prune": custom_autotune.matmul248_kernel_config_pruner,
"perf_model": None,
"top_k": None,
},
)
@triton.jit
def quant_matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, K) float16
B is of shape (K//8, N) int32
C is of shape (M, N) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_k[:, None] // infearure_per_bits) * stride_bk + offs_bn[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_k
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_bn[None, :]
zeros_ptrs = zeros_ptr + (offs_bn[None, :] // infearure_per_bits)
shifter = (offs_k % infearure_per_bits) * bits
zeros_shifter = (offs_bn % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
for k in range(0, num_pid_k):
g_idx = tl.load(g_ptrs)
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs + g_idx[:, None] * stride_scales) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs + g_idx[:, None] * stride_zeros) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_K)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_K
b_ptrs += (BLOCK_SIZE_K // infearure_per_bits) * stride_bk
g_ptrs += BLOCK_SIZE_K
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bn[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
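# Worked example of the unpacking arithmetic above (illustrative, bits=4, maxq=15): one packed
# int32 word 0x76543210 holds eight 4-bit weights; for the j-th of them shifter = (j % 8) * 4,
# so (0x76543210 >> shifter) & maxq recovers 0, 1, ..., 7 in turn, which is then shifted by the
# per-group zero point and multiplied by the per-group scale.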
@custom_autotune.autotune(
configs=[
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
},
num_stages=4,
num_warps=4,
),
triton.Config(
{
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
},
num_stages=2,
num_warps=8,
),
],
key=["M", "N", "K"],
nearest_power_of_two=True,
)
@triton.jit
def transpose_quant_matmul_248_kernel(
a_ptr,
b_ptr,
c_ptr,
scales_ptr,
zeros_ptr,
g_ptr,
M,
N,
K,
bits,
maxq,
stride_am,
stride_ak,
stride_bk,
stride_bn,
stride_cm,
stride_cn,
stride_scales,
stride_zeros,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
GROUP_SIZE_M: tl.constexpr,
):
"""
Compute the matrix multiplication C = A x B.
A is of shape (M, N) float16
B is of shape (K//8, N) int32
C is of shape (M, K) float16
scales is of shape (G, N) float16
zeros is of shape (G, N) float16
g_ptr is of shape (K) int32
"""
infearure_per_bits = 32 // bits
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
num_pid_k = tl.cdiv(K, BLOCK_SIZE_K)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_k
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + (pid % group_size_m)
pid_k = (pid % num_pid_in_group) // group_size_m
offs_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_bk = pid_k * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
offs_n = tl.arange(0, BLOCK_SIZE_N)
a_ptrs = a_ptr + (offs_am[:, None] * stride_am + offs_n[None, :] * stride_ak) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
a_mask = offs_am[:, None] < M
# b_ptrs is set up such that it repeats elements along the K axis 8 times
b_ptrs = b_ptr + (
(offs_bk[:, None] // infearure_per_bits) * stride_bk + offs_n[None, :] * stride_bn
) # (BLOCK_SIZE_K, BLOCK_SIZE_N)
g_ptrs = g_ptr + offs_bk
g_idx = tl.load(g_ptrs)
# shifter is used to extract the N bits of each element in the 32-bit word from B
scales_ptrs = scales_ptr + offs_n[None, :] + g_idx[:, None] * stride_scales
zeros_ptrs = zeros_ptr + (offs_n[None, :] // infearure_per_bits) + g_idx[:, None] * stride_zeros
shifter = (offs_bk % infearure_per_bits) * bits
zeros_shifter = (offs_n % infearure_per_bits) * bits
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_K), dtype=tl.float32)
for k in range(0, num_pid_n):
# Fetch scales and zeros; these are per-outfeature and thus reused in the inner loop
scales = tl.load(scales_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = tl.load(zeros_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N,)
zeros = (zeros >> zeros_shifter[None, :]) & maxq
zeros = zeros + 1
a = tl.load(a_ptrs, mask=a_mask, other=0.0) # (BLOCK_SIZE_M, BLOCK_SIZE_N)
b = tl.load(b_ptrs) # (BLOCK_SIZE_K, BLOCK_SIZE_N), but repeated
# Now we need to unpack b (which is N-bit values) into 32-bit values
b = (b >> shifter[:, None]) & maxq # Extract the N-bit values
b = (b - zeros) * scales # Scale and shift
b = tl.trans(b)
accumulator += tl.dot(a, b)
a_ptrs += BLOCK_SIZE_N
b_ptrs += BLOCK_SIZE_N
scales_ptrs += BLOCK_SIZE_N
zeros_ptrs += BLOCK_SIZE_N // infearure_per_bits
c_ptrs = c_ptr + stride_cm * offs_am[:, None] + stride_cn * offs_bk[None, :]
c_mask = (offs_am[:, None] < M) & (offs_bk[None, :] < K)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit
def silu(x):
return x * tl.sigmoid(x)
def quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=input.dtype)
grid = lambda META: ( # noqa: E731
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
quant_matmul_248_kernel[grid](
input,
qweight,
output,
scales.to(input.dtype),
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
def transpose_quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output_dim = (qweight.shape[0] * 32) // bits
output = torch.empty((input.shape[0], output_dim), device=input.device, dtype=input.dtype)
grid = lambda META: ( # noqa: E731
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(output_dim, META["BLOCK_SIZE_K"]),
)
transpose_quant_matmul_248_kernel[grid](
input,
qweight,
output,
scales.to(input.dtype),
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
output_dim,
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearFunction(torch.autograd.Function):
@staticmethod
@custom_fwd
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
ctx.save_for_backward(qweight, scales, qzeros, g_idx)
ctx.bits, ctx.maxq = bits, maxq
return output
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
qweight, scales, qzeros, g_idx = ctx.saved_tensors
bits, maxq = ctx.bits, ctx.maxq
grad_input = None
if ctx.needs_input_grad[0]:
grad_input = transpose_quant_matmul_248(grad_output, qweight, scales, qzeros, g_idx, bits, maxq)
return grad_input, None, None, None, None, None, None
def quant_matmul_inference_only_248(input, qweight, scales, qzeros, g_idx, bits, maxq):
with torch.cuda.device(input.device):
output = torch.empty((input.shape[0], qweight.shape[1]), device=input.device, dtype=torch.float16)
grid = lambda META: ( # noqa: E731
triton.cdiv(input.shape[0], META["BLOCK_SIZE_M"]) * triton.cdiv(qweight.shape[1], META["BLOCK_SIZE_N"]),
)
quant_matmul_248_kernel[grid](
input,
qweight,
output,
scales,
qzeros,
g_idx,
input.shape[0],
qweight.shape[1],
input.shape[1],
bits,
maxq,
input.stride(0),
input.stride(1),
qweight.stride(0),
qweight.stride(1),
output.stride(0),
output.stride(1),
scales.stride(0),
qzeros.stride(0),
)
return output
class QuantLinearInferenceOnlyFunction(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input, qweight, scales, qzeros, g_idx, bits, maxq):
output = quant_matmul_248(input, qweight, scales, qzeros, g_idx, bits, maxq)
return output
class TritonModuleMixin:
@classmethod
def warmup(cls, model, transpose=False, seqlen=2048):
pass
The code in this directory is mainly adapted from @qwopqwop200's [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa/tree/cuda), which itself is based on [gptq](https://github.com/IST-DASLab/gptq).
from .config import (
CHECKPOINT_FORMAT,
CHECKPOINT_FORMAT_FIELD,
CHECKPOINT_FORMAT_FIELD_COMPAT_MARLIN,
QUANT_CONFIG_FILENAME,
QUANT_METHOD,
QUANT_METHOD_FIELD,
BaseQuantizeConfig,
)
from .gptq import GPTQ
from .quantizer import Quantizer, quantize