Commit eb8e460c authored by nicodafagood

update mygq

parent 23fdbb68
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
@@ -258,7 +258,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
@@ -115,16 +115,16 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
-torch::Tensor myq_gemm(
+torch::Tensor mygq_gemm(
   torch::Tensor a,
   torch::Tensor b_q_weight,
-  torch::Tensor b_myq_qzeros,
-  torch::Tensor b_myq_scales,
+  torch::Tensor b_mygq_qzeros,
+  torch::Tensor b_mygq_scales,
   torch::Tensor b_g_idx,
   bool use_exllama,
   int bit);
-void myq_shuffle(
+void mygq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm,
   int bit);
@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  ops.def("myq_gemm", &myq_gemm, "Quantized GEMM for myq");
-  ops.def("myq_shuffle", &myq_shuffle, "Post processing for GPTQ");
+  ops.def("mygq_gemm", &mygq_gemm, "Quantized GEMM for mygq");
+  ops.def("mygq_shuffle", &mygq_shuffle, "Post processing for mygq");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
     "moe_align_block_size",
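These bindings are what the Python quantization layer ultimately calls. Below is a minimal sketch of the call pattern the renamed ops imply, mirroring the `MYQLinearMethod` hunk near the end of this diff; the import path and the `mygq_matmul` helper are illustrative assumptions, with the argument order taken from the `ops.h` declarations above.

```python
import torch

from vllm._C import ops  # assumption: the rebuilt extension exposing mygq_gemm / mygq_shuffle


def mygq_matmul(x: torch.Tensor, weights: dict, weight_bits: int,
                use_exllama: bool) -> torch.Tensor:
    """Hypothetical driver mirroring MYQLinearMethod.apply_weights (sketch only)."""
    reshaped_x = x.reshape(-1, x.shape[-1])
    # Post-processing of the packed weights (q_weight, q_perm, bit); in the real
    # layer this runs once, guarded by ExllamaState.
    ops.mygq_shuffle(weights["qweight"], weights["g_idx"], weight_bits)
    # Quantized GEMM: (a, b_q_weight, b_mygq_qzeros, b_mygq_scales, b_g_idx,
    # use_exllama, bit), as declared in ops.h above.
    output = ops.mygq_gemm(reshaped_x, weights["qweight"], weights["qzeros"],
                           weights["scales"], weights["g_idx"], use_exllama,
                           weight_bits)
    return output.reshape(x.shape[:-1] + (output.shape[-1], ))
```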
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _compat_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 // atomicAdd for half types, to support CC < 7.x
 __device__ __forceinline__ void atomicAdd_half(half* address, half val)
@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
 #endif
 #endif
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 class MatrixView_half
 {
@@ -269,6 +269,6 @@ public:
     }
 };
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
     dq[7] = __hfma2(q7.as_half2, y64, z64);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -4,7 +4,7 @@
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // v9997775 55333111 u8886664 44222000 (u, v lsb)
@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
     dq[15] = __hadd2(q15.as_half2, z1);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // 77775555 33331111 66664444 22220000
@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
 }
-__forceinline__ __device__ void dequant_4bit_8_myq
+__forceinline__ __device__ void dequant_4bit_8_mygq
 (
     const uint32_t q_0,
     half2 (&dq)[4],
@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
     dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
     }
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 __forceinline__ __device__ void shuffle_8bit_4
 (
@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
     for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _qdq_util_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 union half2_uint32
 {
@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
     return (int)(__funnelshift_rc(q0, q1, shift) & mask);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -339,7 +339,7 @@ vllm_extension_sources = [
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     "csrc/quantization/gptq/q_gemm.cu",
-    "csrc/quantization/myq/q_gemm.cu",
+    "csrc/quantization/mygq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/moe_align_block_size_kernels.cu",
     "csrc/pybind.cpp",
@@ -155,7 +155,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","myq"]
+        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","mygq"]
         rocm_not_supported_quantization = ["awq", "marlin"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
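For orientation, `supported_quantization` is the allow-list that `_verify_quantization` checks the engine's `quantization` argument against. The rest of that method is outside this diff, but a self-contained sketch of the behaviour it presumably implements is shown below; the standalone function name and error message are illustrative, not code from the repository.

```python
from typing import List, Optional


def verify_quantization(quantization: Optional[str],
                        supported: List[str]) -> Optional[str]:
    """Illustrative stand-in for the check ModelConfig._verify_quantization performs."""
    if quantization is None:
        return None
    quantization = quantization.lower()
    if quantization not in supported:
        raise ValueError(f"Unknown quantization method: {quantization!r}. "
                         f"Must be one of {supported}.")
    return quantization


# After this commit, the lower-cased name "mygq" passes the check.
assert verify_quantization(
    "MYGQ", ["awq", "gptq", "squeezellm", "marlin", "mygq"]) == "mygq"
```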
@@ -208,7 +208,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm','myq', None],
+                            choices=['awq', 'gptq', 'squeezellm','mygq', None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '
@@ -3,13 +3,13 @@ from typing import Type
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.myq import MYQConfig
+from vllm.model_executor.layers.quantization.mygq import MYQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
-    "myq": MYQConfig,
+    "mygq": MYQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
     "marlin": MarlinConfig,
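The registry key must match the string returned by the config's `get_name()` (changed in the next hunk), because the engine resolves the method name to a config class through this mapping. A sketch of that lookup as it would sit next to the registry above; the helper name is an assumption rather than code shown in this diff.

```python
from typing import Type


# Assumed lookup helper living in the same module as _QUANTIZATION_CONFIG_REGISTRY.
def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
    if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
        raise ValueError(f"Invalid quantization method: {quantization}")
    return _QUANTIZATION_CONFIG_REGISTRY[quantization]


# "mygq" now resolves to MYQConfig; the old "myq" key would raise ValueError.
config_cls = get_quantization_config("mygq")
```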
@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
     @classmethod
     def get_name(cls) -> str:
-        return "myq"
+        return "mygq"
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            ops.myq_shuffle(weights["qweight"], weights["g_idx"],
+            ops.mygq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
-        output = ops.myq_gemm(reshaped_x, weights["qweight"],
+        output = ops.mygq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
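Taken together, the rename has to stay consistent across the CUDA namespaces, the `ops.h`/pybind registrations, the extension source path in `setup.py`, the engine-level allow-lists, and the registry key returned by `get_name()`. Once the extension is rebuilt, selecting the method end to end would look roughly like the sketch below; the model path is a placeholder, and a checkpoint actually packed for the mygq format is assumed.

```python
from vllm import LLM, SamplingParams

# Placeholder path: any checkpoint whose weights are packed for the "mygq" method.
llm = LLM(model="path/to/mygq-quantized-model", quantization="mygq")

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.8, max_tokens=32))
print(outputs[0].outputs[0].text)
```

The same string is what the benchmark scripts at the top of this diff now accept via `--quantization mygq` / `-q mygq`.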