Unverified commit 0fbfc4b8, authored by CHU Tianxiang, committed by GitHub

Add GPTQ support (#916)

parent c06170cc
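With this commit, GPTQ checkpoints can be selected the same way as AWQ and SqueezeLLM ones (see the `--quantization` / `quantization` changes below). A minimal usage sketch; the model name is a placeholder for any 4-bit GPTQ checkpoint that ships a quantize_config.json:

from vllm import LLM, SamplingParams

# "my-org/my-model-gptq" is a placeholder, not a real checkpoint.
llm = LLM(model="my-org/my-model-gptq", quantization="gptq")
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)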
@@ -84,7 +84,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
......
@@ -244,7 +244,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'squeezellm', None],
+                        choices=['awq', 'gptq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
......
@@ -77,3 +77,15 @@ void squeezellm_gemm(
   torch::Tensor mat,
   torch::Tensor mul,
   torch::Tensor lookup_table);
+
+torch::Tensor gptq_gemm(
+  torch::Tensor a,
+  torch::Tensor b_q_weight,
+  torch::Tensor b_gptq_qzeros,
+  torch::Tensor b_gptq_scales,
+  torch::Tensor b_g_idx,
+  bool use_exllama);
+
+void gptq_shuffle(
+  torch::Tensor q_weight,
+  torch::Tensor q_perm);
@@ -52,8 +52,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   // Quantization ops
   ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
 #endif
+  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
+  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   // Cache ops
......
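The two ops registered above are what the new GPTQLinearMethod (later in this diff) calls at runtime. A rough shape-level sketch of the calls, inferred from gptq.py below; the tensors are uninitialized and the sizes are purely illustrative:

import torch
from vllm._C import ops

in_features, out_features, group_size = 4096, 4096, 128
pack_factor = 32 // 4  # eight 4-bit weights per int32

qweight = torch.empty(in_features // pack_factor, out_features,
                      dtype=torch.int32, device="cuda")
qzeros = torch.empty(in_features // group_size, out_features // pack_factor,
                     dtype=torch.int32, device="cuda")
scales = torch.empty(in_features // group_size, out_features,
                     dtype=torch.float16, device="cuda")
# With desc_act=False, apply_weights passes an empty meta tensor as g_idx.
g_idx = torch.empty((1, 1), device="meta")
x = torch.empty(8, in_features, dtype=torch.float16, device="cuda")

ops.gptq_shuffle(qweight, g_idx)  # one-time repack for the exllama kernel
out = ops.gptq_gemm(x, qweight, qzeros, scales, g_idx, True)  # use_exllama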
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _compat_cuh
#define _compat_cuh
namespace vllm {
namespace gptq {
// atomicAdd for half types, to support CC < 7.x
__device__ __forceinline__ void atomicAdd_half(half* address, half val)
{
unsigned int * address_as_ui = (unsigned int *) ((char *)address - ((size_t)address & 2));
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
__half_raw hsum;
hsum.x = (size_t)address & 2 ? (old >> 16) : (old & 0xffff);
half tmpres = __hadd(hsum, val);
hsum = __half_raw(tmpres);
old = (size_t)address & 2 ? (old & 0xffff) | (hsum.x << 16) : (old & 0xffff0000) | hsum.x;
old = atomicCAS(address_as_ui, assumed, old);
}
while (assumed != old);
}
// atomicAdd for half2 types
__device__ __forceinline__ void atomicAdd_half2(half2* address, half2 val)
{
unsigned int* address_as_ui = (unsigned int*)address;
unsigned int old = *address_as_ui;
unsigned int assumed;
do
{
assumed = old;
half2 old_val = *((half2*)&old);
half2 new_val = __hadd2(old_val, val);
old = atomicCAS(address_as_ui, assumed, *((unsigned int*)&new_val));
}
while (assumed != old);
}
//
#if defined(__CUDA_ARCH__) || defined(USE_ROCM)
#if __CUDA_ARCH__ < 700 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half* address, half val) { atomicAdd_half(address, val); }
#if __CUDA_ARCH__ < 600 || defined(USE_ROCM)
__device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd_half2(address, val); }
#endif
#endif
#endif
} // namespace gptq
} // namespace vllm
#endif
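The CAS loop above emulates a 16-bit atomic add on architectures without native half atomics by operating on the aligned 32-bit word that contains the half. A host-side Python sketch of the lane selection it relies on (illustrative only, ignoring half-precision rounding details):

import struct

def add_half_into_word(word: int, byte_addr: int, val: float) -> int:
    # The half at byte_addr occupies the high 16 bits of its aligned 32-bit
    # word iff (byte_addr & 2) != 0 -- the same test atomicAdd_half makes.
    hi = bool(byte_addr & 2)
    lane_bits = (word >> 16) & 0xFFFF if hi else word & 0xFFFF
    lane_val = struct.unpack("<e", struct.pack("<H", lane_bits))[0]
    new_bits = struct.unpack("<H", struct.pack("<e", lane_val + val))[0]
    return (word & 0x0000FFFF) | (new_bits << 16) if hi \
        else (word & 0xFFFF0000) | new_bits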
/*
Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turboderp/exllama
*/
#ifndef _matrix_view_cuh
#define _matrix_view_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
__device__ __forceinline__ void item4(half (&items)[4], int row, int column) const
{
half2* ptr = (half2*) item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __low2half(i01);
items[1] = __high2half(i01);
items[2] = __low2half(i23);
items[3] = __high2half(i23);
}
__device__ __forceinline__ void item4_f(float (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2float(__low2half(i01));
items[1] = __half2float(__high2half(i01));
items[2] = __half2float(__low2half(i23));
items[3] = __half2float(__high2half(i23));
}
__device__ __forceinline__ void item4_h2(half2 (&items)[4], int row, int column) const
{
half2* ptr = (half2*)item_ptr(row, column);
half2 i01 = ptr[0];
half2 i23 = ptr[1];
items[0] = __half2half2(__low2half(i01));
items[1] = __half2half2(__high2half(i01));
items[2] = __half2half2(__low2half(i23));
items[3] = __half2half2(__high2half(i23));
}
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
__device__ __forceinline__ void set4(int row, int column, half v0, half v1, half v2, half v3)
{
half2 v01 = __halves2half2(v0, v1);
half2 v23 = __halves2half2(v2, v3);
half2* ptr = (half2*) item_ptr(row, column);
ptr[0] = v01;
ptr[1] = v23;
}
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
__device__ __forceinline__ void item2(int (&items)[2], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
}
__device__ __forceinline__ void item4(int (&items)[4], int row, int column) const
{
int shift = (column & 0x07) * 4;
uint32_t d = data[row * width / 8 + column / 8] >> shift;
items[0] = d & 0x0f;
items[1] = (d >> 4) & 0x0f;
items[2] = (d >> 8) & 0x0f;
items[3] = (d >> 12) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
} // namespace gptq
} // namespace vllm
#endif
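MatrixView_q4_row and MatrixView_q4_column above read 4-bit values packed eight per uint32 (row-major and column-major respectively). A small Python reference for the row-packed case, mirroring MatrixView_q4_row::item:

def q4_row_item(data, width, row, column):
    # Eight 4-bit values per uint32 along a row; the lowest nibble holds the
    # lowest column index within the word.
    word = data[row * width // 8 + column // 8]
    shift = (column & 0x07) * 4
    return (word >> shift) & 0x0F

packed = [0x76543210]              # width = 8: nibbles 0..7 in one word
assert q4_row_item(packed, 8, 0, 3) == 3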
This diff is collapsed.
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_4_cuh
#define _qdq_4_cuh
#include "qdq_util.cuh"
namespace vllm {
namespace gptq {
// Permutation:
//
// 77775555 33331111 66664444 22220000
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
uint32_t qa = q[0];
uint32_t qb = 0;
#pragma unroll
for (int i = 0; i < 4; i++)
{
uint32_t qa0 = qa & 0x0f;
uint32_t qa1 = (qa & 0xf0) >> 4;
qa >>= 8;
qb |= (qa1 << (i * 4 + 16));
qb |= (qa0 << (i * 4));
}
q[0] = qb;
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
const uint32_t c0 = 0x64006400;
const half y16_ = __float2half_rn(1.0f / 16.0f);
const half2 y16 = __halves2half2(y16_, y16_);
const half z1_ = __float2half_rn(-1024.0f - 8.0f);
const half z16_ = __float2half_rn(-1024.0f / 16.0f - 8.0f);
const half2 z1 = __halves2half2(z1_, z1_);
const half2 z16 = __halves2half2(z16_, z16_);
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2(q[ 0], q[ 1]) + 1024
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2(q[ 2], q[ 3]) * 16 + 1024
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2(q[ 4], q[ 5]) + 1024
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2(q[ 6], q[ 7]) * 16 + 1024
dq[0] = __hadd2(q0.as_half2, z1);
dq[1] = __hfma2(q1.as_half2, y16, z16);
dq[2] = __hadd2(q2.as_half2, z1);
dq[3] = __hfma2(q3.as_half2, y16, z16);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1z16)[2],
half2 (&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
half2 scale2 = __half2half2(scale);
z1z16[0] = __hmul2(scale2, __half2half2(z1.as_half));
z1z16[1] = __hmul2(scale2, __half2half2(z16));
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __hmul2(scale2, __half2half2(y1));
y1y16[1] = __hmul2(scale2, __half2half2(y16));
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1z16)[2],
half2(&y1y16)[2]
)
{
half_uint16 z1(0xe400 | zero); // half(-1024.0f - zero);
half z16 = __hsub(__int2half_rn(-64), __int2half_rn(zero));
z1z16[0] = __half2half2(z1.as_half);
z1z16[1] = __half2half2(z16);
const half y1 = __float2half_rn(1.0f);
const half y16 = __float2half_rn(1.0f / 16.0f);
y1y16[0] = __half2half2(y1);
y1y16[1] = __half2half2(y16);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1z16)[2],
half2 (&y1y16)[2],
int stride,
bool scaled
)
{
const uint32_t c0 = 0x64006400;
uint32_t qa = q_0;
half2_uint32 q0((qa & 0x000f000f) | c0); // half2( q[0] + 1024, q[1] + 1024 )
half2_uint32 q1((qa & 0x00f000f0) | c0); // half2( q[2] * 16 + 1024, q[3] * 16 + 1024 )
qa >>= 8;
half2_uint32 q2((qa & 0x000f000f) | c0); // half2( q[4] + 1024, q[5] + 1024 )
half2_uint32 q3((qa & 0x00f000f0) | c0); // half2( q[6] * 16 + 1024, q[7] * 16 + 1024 )
if (scaled)
{
dq[0] = __hfma2(q0.as_half2, y1y16[0], z1z16[0]); // half2( q[0] * s - z * s, q[1] * s - z * s)
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] * s - z * s, q[3] * s - z * s)
dq[2] = __hfma2(q2.as_half2, y1y16[0], z1z16[0]);
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]);
}
else
{
dq[0] = __hadd2(q0.as_half2, z1z16[0]); // half2( q[0] - z, q[1] - z )
dq[1] = __hfma2(q1.as_half2, y1y16[1], z1z16[1]); // half2( q[2] - z, q[3] - z )
dq[2] = __hadd2(q2.as_half2, z1z16[0]); // half2( q[4] - z, q[5] - z )
dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
}
}
} // namespace gptq
} // namespace vllm
#else
namespace vllm {
namespace gptq {
__forceinline__ __device__ void shuffle_4bit_8
(
uint32_t* q,
int stride
)
{
}
__forceinline__ __device__ void dequant_4bit_8
(
const uint32_t q_0,
half2 (&dq)[4],
int stride
)
{
half dqh[8];
for (int i = 0; i < 8; i++) dqh[i] = dq_ns(exb(q_0, i * 4, 0x0f), 8);
for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero_scale
(
const uint32_t zero,
const half scale,
half2 (&z1)[2],
half2 (&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z = __hmul(z, scale);
z1[0] = __half2half2(z);
y1[0] = __half2half2(scale);
}
__forceinline__ __device__ void dequant_4bit_8_prep_zero
(
const uint32_t zero,
half2(&z1)[2],
half2(&y1)[2]
)
{
half z = __int2half_rn(-((int)zero));
z1[0] = __half2half2(z);
}
__forceinline__ __device__ void dequant_4bit_8_gptq
(
const uint32_t q_0,
half2 (&dq)[4],
half2 (&z1)[2],
half2 (&y1)[2],
int stride,
bool scaled
)
{
half2 dqh2[8];
uint32_t qa = q_0;
for (int i = 0; i < 4; i++)
{
half d0 = __int2half_rn(qa & 0x0f); qa >>= 4;
half d1 = __int2half_rn(qa & 0x0f); qa >>= 4;
dqh2[i] = __halves2half2(d0, d1);
}
if (scaled)
{
dq[0] = __hfma2(dqh2[0], y1[0], z1[0]);
dq[1] = __hfma2(dqh2[1], y1[0], z1[0]);
dq[2] = __hfma2(dqh2[2], y1[0], z1[0]);
dq[3] = __hfma2(dqh2[3], y1[0], z1[0]);
}
else
{
dq[0] = __hadd2(dqh2[0], z1[0]);
dq[1] = __hadd2(dqh2[1], z1[0]);
dq[2] = __hadd2(dqh2[2], z1[0]);
dq[3] = __hadd2(dqh2[3], z1[0]);
}
}
} // namespace gptq
} // namespace vllm
#endif
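The constants in dequant_4bit_8 above rely on an fp16 bit trick: in the 1024-2048 binade the mantissa step is exactly 1, so OR-ing a 4-bit value q into 0x6400 yields the half 1024 + q, and OR-ing the next nibble (still sitting at bits 4-7) yields 1024 + 16*q; the z1, y16 and z16 constants then remove the bias and the GPTQ zero offset of 8. A quick NumPy check of that arithmetic (host-side, illustrative only):

import numpy as np

def as_half(bits: int) -> float:
    return float(np.array([bits], dtype=np.uint16).view(np.float16)[0])

for q in range(16):
    lo = as_half(0x6400 | q)         # nibble in bits 0-3
    hi = as_half(0x6400 | (q << 4))  # nibble in bits 4-7
    assert lo == 1024.0 + q
    assert hi == 1024.0 + 16 * q
    # z1 and (y16, z16) from the kernel recover q - 8 in both cases:
    assert lo + (-1024.0 - 8.0) == q - 8
    assert hi * (1.0 / 16.0) + (-1024.0 / 16.0 - 8.0) == q - 8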
/*
Copied from https://github.com/turboderp/exllamav2
*/
#ifndef _qdq_util_cuh
#define _qdq_util_cuh
namespace vllm {
namespace gptq {
union half2_uint32
{
uint32_t as_uint32;
half2 as_half2;
__device__ half2_uint32(uint32_t val) : as_uint32(val) {}
__device__ half2_uint32(half2 val) : as_half2(val) {}
};
union half_uint16
{
uint16_t as_uint16;
half as_half;
__device__ half_uint16(uint16_t val) : as_uint16(val) {}
__device__ half_uint16(half val) : as_half(val) {}
};
// Max_scale premultiplied by 1/256
__forceinline__ __device__ half dq_scale(const int qs, const half max_scale)
{
int qs_i = qs + 1;
half qs_h = __int2half_rn(qs_i * qs_i);
qs_h = __hmul(qs_h, max_scale);
return qs_h;
}
__forceinline__ __device__ half dq(const int q, const int qzero, const half scale)
{
return __hmul(__int2half_rn(q - qzero), scale);
}
__forceinline__ __device__ half dq_ns(const int q, const int qzero)
{
//return __hsub(__int2half_rn(q), __int2half_rn(qzero));
return __int2half_rn(q - qzero);
}
__forceinline__ __device__ int exb(const uint32_t q, const int shift, const int mask)
{
return (int)((q >> shift) & mask);
}
__forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const int shift, const int mask)
{
return (int)(__funnelshift_rc(q0, q1, shift) & mask);
}
} // namespace gptq
} // namespace vllm
#endif
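The helpers above are simple scalar operations; a one-to-one Python restatement for reference (illustrative only):

def dq(q: int, qzero: int, scale: float) -> float:
    # Dequantize with an explicit zero point and scale.
    return (q - qzero) * scale

def dq_ns(q: int, qzero: int) -> float:
    # "No scale" variant used by the fallback dequant path in qdq_4.cuh.
    return float(q - qzero)

def dq_scale(qs: int, max_scale: float) -> float:
    # max_scale arrives premultiplied by 1/256, so qs = 15 maps back to
    # (15 + 1) ** 2 / 256 = 1.0 times the original maximum scale.
    return (qs + 1) * (qs + 1) * max_scale

def exb(q: int, shift: int, mask: int) -> int:
    # Extract a bit field from a packed 32-bit word.
    return (q >> shift) & mask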
@@ -219,6 +219,7 @@ vllm_extension_sources = [
     "csrc/activation_kernels.cu",
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
+    "csrc/quantization/gptq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/pybind.cpp",
 ]
......
@@ -142,7 +142,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode

     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "squeezellm"]
+        supported_quantization = ["awq", "gptq", "squeezellm"]
         rocm_not_supported_quantization = ["awq"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
......
@@ -179,7 +179,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'squeezellm', None],
+                            choices=['awq', 'gptq', 'squeezellm', None],
                             default=None,
                             help='Method used to quantize the weights')
         return parser
......
@@ -38,8 +38,9 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         quantization: The method used to quantize the model weights. Currently,
-            we support "awq". If None, we assume the model weights are not
-            quantized and use `dtype` to determine the data type of the weights.
+            we support "awq", "gptq" and "squeezellm". If None, we assume the
+            model weights are not quantized and use `dtype` to determine the
+            data type of the weights.
         revision: The specific model version to use. It can be a branch name,
             a tag name, or a commit id.
         tokenizer_revision: The specific tokenizer version to use. It can be a
......
 from abc import ABC, abstractmethod
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 import torch
 import torch.nn.functional as F
@@ -21,8 +21,10 @@ class LinearMethodBase(ABC):
     """Base class for different (maybe quantized) linear methods."""

     @abstractmethod
-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
         """Create weights for a linear layer."""
         raise NotImplementedError
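For downstream implementers, the interface change above means a linear method now receives both the per-partition and the full layer sizes, and may return non-tensor state in its weight dict (which is why the callers below guard register_parameter with an isinstance check). A minimal conforming sketch; the class name and the "state" entry are illustrative, not part of this commit:

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter

class PassthroughLinearMethod:  # mirrors the new LinearMethodBase contract
    def create_weights(self, input_size_per_partition: int,
                       output_size_per_partition: int, input_size: int,
                       output_size: int,
                       params_dtype: torch.dtype) -> Dict[str, Any]:
        # Allocate only the local shard; the full input_size / output_size
        # are available for methods (like GPTQ) that need them.
        weight = Parameter(torch.empty(output_size_per_partition,
                                       input_size_per_partition,
                                       dtype=params_dtype),
                           requires_grad=False)
        return {"weight": weight, "state": "unquantized"}  # non-tensor allowed

    def apply_weights(self,
                      weights: Dict[str, Any],
                      x: torch.Tensor,
                      bias: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, weights["weight"], bias)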
@@ -46,10 +48,12 @@ class UnquantizedLinearMethod(LinearMethodBase):
     def __init__(self, separate_bias_add: bool = False):
         self.separate_bias_add = separate_bias_add

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        weight = Parameter(torch.empty(output_size,
-                                       input_size,
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        weight = Parameter(torch.empty(output_size_per_partition,
+                                       input_size_per_partition,
                                        device=torch.cuda.current_device(),
                                        dtype=params_dtype),
                            requires_grad=False)
@@ -102,8 +106,10 @@ class ReplicatedLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size, self.params_dtype)
+            self.input_size, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
         if bias:
             self.bias = Parameter(
@@ -168,8 +174,10 @@ class ColumnParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size, self.output_size_per_partition, self.params_dtype)
+            self.input_size, self.output_size_per_partition, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
         if bias:
@@ -295,6 +303,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
             else:
+                ignore_warning = getattr(param, "ignore_warning", False)
+                if not ignore_warning:
                     logger.warning(
                         "Loading a weight without `output_dim` attribute in "
                         "MergedColumnParallelLinear, assume the weight is "
@@ -418,6 +428,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                 loaded_weight = loaded_weight.narrow(output_dim, start_idx,
                                                      shard_size)
             else:
+                ignore_warning = getattr(param, "ignore_warning", False)
+                if not ignore_warning:
                     logger.warning(
                         "Loading a weight without `output_dim` attribute in "
                         "QKVParallelLinear, assume the weight is the same "
@@ -481,8 +493,10 @@ class RowParallelLinear(torch.nn.Module):
             linear_method = UnquantizedLinearMethod()
         self.linear_method = linear_method
         self.linear_weights = self.linear_method.create_weights(
-            self.input_size_per_partition, self.output_size, self.params_dtype)
+            self.input_size_per_partition, self.output_size, self.input_size,
+            self.output_size, self.params_dtype)
         for name, weight in self.linear_weights.items():
-            self.register_parameter(name, weight)
-            set_weight_attrs(weight, {"weight_loader": self.weight_loader})
+            if isinstance(weight, torch.Tensor):
+                self.register_parameter(name, weight)
+                set_weight_attrs(weight, {"weight_loader": self.weight_loader})
......
 from typing import Type

-from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
+from vllm.model_executor.layers.quantization.gptq import GPTQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig

 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
+    "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
 }
......
@@ -77,14 +77,16 @@ class AWQLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: AWQConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.group_size != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
-        if output_size % self.quant_config.pack_factor != 0:
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The output size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
@@ -92,8 +94,8 @@ class AWQLinearMethod(LinearMethodBase):
         qweight = Parameter(
             torch.empty(
-                input_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -108,8 +110,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         qzeros = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size // self.quant_config.pack_factor,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition // self.quant_config.pack_factor,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -124,8 +126,8 @@ class AWQLinearMethod(LinearMethodBase):
             })
         scales = Parameter(
             torch.empty(
-                input_size // self.quant_config.group_size,
-                output_size,
+                input_size_per_partition // self.quant_config.group_size,
+                output_size_per_partition,
                 device="cuda",
                 dtype=params_dtype,
             ),
@@ -142,7 +144,7 @@ class AWQLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
......
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm._C import ops
from vllm.model_executor.layers.linear import (LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
Reference: https://arxiv.org/abs/2210.17323
"""
def __init__(
self,
weight_bits: int,
group_size: int,
desc_act: bool,
) -> None:
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.pack_factor = 32 // self.weight_bits
# exllama kernel v1 only supports 4 bit
if self.weight_bits != 4:
raise ValueError(
"Currently, only 4-bit weight quantization is supported for "
f"GPTQ, but got {self.weight_bits} bits.")
def __repr__(self) -> str:
return (f"GPTQConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half]
@classmethod
# Need to figure it out
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
return cls(weight_bits, group_size, desc_act)
def get_linear_method(self) -> "GPTQLinearMethod":
return GPTQLinearMethod(self)
def get_scaled_act_names(self) -> List[str]:
return []
class ExllamaState(Enum):
UNUSED = enum.auto()
UNINITIALIZED = enum.auto()
READY = enum.auto()
class GPTQLinearMethod(LinearMethodBase):
"""Linear method for GPTQ.
Args:
quant_config: The GPTQ quantization config.
"""
def __init__(self, quant_config: GPTQConfig):
self.quant_config = quant_config
def create_weights(
self,
input_size_per_partition: int,
output_size_per_partition: int,
input_size: int,
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
del output_size # Unused.
if input_size_per_partition % self.quant_config.group_size != 0:
raise ValueError(
"The input size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
"The output size is not aligned with the quantized "
"weight shape. This can be caused by too large "
"tensor parallel size.")
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
exllama_state = ExllamaState.UNINITIALIZED
scale_and_zero_size = input_size // group_size
scale_and_zero_input_dim = None
if input_size != input_size_per_partition and self.quant_config.group_size != -1:
# For act-order models, we cannot use Exllama for row parallel layer
if self.quant_config.desc_act:
exllama_state = ExllamaState.UNUSED
else:
# we need to partition qzeros and scales for exllama kernel
scale_and_zero_size = input_size_per_partition // group_size
scale_and_zero_input_dim = 0
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight, {
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
})
g_idx = Parameter(
torch.tensor(
[
i // self.quant_config.group_size
for i in range(input_size_per_partition)
],
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(g_idx, {"input_dim": 0, "ignore_warning": True})
qzeros = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition // self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qzeros, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
})
scales = Parameter(
torch.empty(
scale_and_zero_size,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(scales, {
"input_dim": scale_and_zero_input_dim,
"output_dim": 1,
})
return {
"qweight": qweight,
"g_idx": g_idx,
"qzeros": qzeros,
"scales": scales,
"exllama_state": exllama_state,
}
def apply_weights(self,
weights: Dict[str, Any],
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = weights["qweight"]
out_shape = x.shape[:-1] + (qweight.shape[-1], )
reshaped_x = x.reshape(-1, x.shape[-1])
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
if weights["exllama_state"] == ExllamaState.UNINITIALIZED:
if self.quant_config.desc_act:
weights["g_idx"] = torch.argsort(weights["g_idx"]).to(
torch.int)
else:
weights["g_idx"] = torch.empty((1, 1), device="meta")
weights["exllama_state"] = ExllamaState.READY
ops.gptq_shuffle(weights["qweight"], weights["g_idx"])
output = ops.gptq_gemm(reshaped_x, weights["qweight"],
weights["qzeros"], weights["scales"],
weights["g_idx"],
weights["exllama_state"] == ExllamaState.READY)
if bias is not None:
output = output + bias
return output.reshape(out_shape)
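For reference, GPTQConfig above is built from a checkpoint's quantize_config.json. A small usage sketch (the values are just an example):

quantize_config = {"bits": 4, "group_size": 128, "desc_act": False}
gptq_config = GPTQConfig.from_config(quantize_config)
print(gptq_config)              # GPTQConfig(weight_bits=4, group_size=128, desc_act=False)
print(gptq_config.pack_factor)  # 8: eight 4-bit values per int32
linear_method = gptq_config.get_linear_method()
# Note that the first apply_weights() call flips exllama_state from
# UNINITIALIZED to READY and runs ops.gptq_shuffle once, so the exllama
# repacking happens lazily, after the checkpoint weights are loaded.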
@@ -67,17 +67,19 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
     def __init__(self, quant_config: SqueezeLLMConfig):
         self.quant_config = quant_config

-    def create_weights(self, input_size: int, output_size: int,
-                       params_dtype: torch.dtype) -> Dict[str, torch.Tensor]:
-        if input_size % self.quant_config.pack_factor != 0:
+    def create_weights(self, input_size_per_partition: int,
+                       output_size_per_partition: int, input_size: int,
+                       output_size: int,
+                       params_dtype: torch.dtype) -> Dict[str, Any]:
+        if input_size_per_partition % self.quant_config.pack_factor != 0:
             raise ValueError(
                 "The input size is not aligned with the quantized "
                 "weight shape. This can be caused by too large "
                 "tensor parallel size.")
         qweight = Parameter(
             torch.empty(
-                input_size // self.quant_config.pack_factor,
-                output_size,
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
                 device="cuda",
                 dtype=torch.int32,
             ),
@@ -108,7 +110,7 @@ class SqueezeLLMLinearMethod(LinearMethodBase):
         }

     def apply_weights(self,
-                      weights: Dict[str, torch.Tensor],
+                      weights: Dict[str, Any],
                       x: torch.Tensor,
                       bias: Optional[torch.Tensor] = None) -> torch.Tensor:
         qweight = weights["qweight"]
......
@@ -332,11 +332,18 @@ class AquilaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
......
@@ -355,11 +355,18 @@ class BaiChuanBaseForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
......