"...composable_kernel_onnxruntime.git" did not exist on "19a93dac051f3b5200fe00151b8fa5994aa890dd"
Commit eb8e460c authored by nicodafagood

update mygq

parent 23fdbb68
@@ -92,7 +92,7 @@ if __name__ == '__main__':
     parser.add_argument('--tokenizer', type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument('--tensor-parallel-size', '-tp', type=int, default=1)
     parser.add_argument('--input-len', type=int, default=32)
...
@@ -258,7 +258,7 @@ if __name__ == "__main__":
     parser.add_argument("--tokenizer", type=str, default=None)
     parser.add_argument('--quantization',
                         '-q',
-                        choices=['awq', 'gptq','myq', 'squeezellm', None],
+                        choices=['awq', 'gptq','mygq', 'squeezellm', None],
                         default=None)
     parser.add_argument("--tensor-parallel-size", "-tp", type=int, default=1)
     parser.add_argument("--n",
...
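Taken together, the two benchmark hunks above only rename the accepted `--quantization` choice. A minimal, runnable sketch of the behavior change (parser fragment reconstructed from the hunks; everything outside the `--quantization` argument is assumed):

    import argparse

    # Sketch of the parser fragment these hunks touch; only the
    # --quantization choices come from the diff, the rest is assumed.
    parser = argparse.ArgumentParser()
    parser.add_argument('--quantization', '-q',
                        choices=['awq', 'gptq', 'mygq', 'squeezellm', None],
                        default=None)

    args = parser.parse_args(['-q', 'mygq'])
    assert args.quantization == 'mygq'  # the old spelling 'myq' is now rejected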
@@ -115,16 +115,16 @@ void gptq_shuffle(
   torch::Tensor q_perm,
   int bit);
-torch::Tensor myq_gemm(
+torch::Tensor mygq_gemm(
   torch::Tensor a,
   torch::Tensor b_q_weight,
-  torch::Tensor b_myq_qzeros,
-  torch::Tensor b_myq_scales,
+  torch::Tensor b_mygq_qzeros,
+  torch::Tensor b_mygq_scales,
   torch::Tensor b_g_idx,
   bool use_exllama,
   int bit);
-void myq_shuffle(
+void mygq_shuffle(
   torch::Tensor q_weight,
   torch::Tensor q_perm,
   int bit);
...
@@ -61,8 +61,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
   ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
-  ops.def("myq_gemm", &myq_gemm, "Quantized GEMM for myq");
-  ops.def("myq_shuffle", &myq_shuffle, "Post processing for GPTQ");
+  ops.def("mygq_gemm", &mygq_gemm, "Quantized GEMM for mygq");
+  ops.def("mygq_shuffle", &mygq_shuffle, "Post processing for mygq");
   ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
   ops.def(
     "moe_align_block_size",
...
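For orientation, the renamed bindings are what the Python side calls later in this commit (see the MYQLinearMethod hunk below). A hedged sketch of how they surface in Python; the extension module path `vllm._C` is assumed from upstream vLLM's layout and is not shown in this diff, while the argument order mirrors the csrc/ops.h declarations above:

    import torch
    from vllm._C import ops  # module path assumed, not shown in this diff

    def mygq_linear(a: torch.Tensor, qweight: torch.Tensor,
                    qzeros: torch.Tensor, qscales: torch.Tensor,
                    g_idx: torch.Tensor, bit: int,
                    use_exllama: bool) -> torch.Tensor:
        # Hypothetical helper wrapping the renamed kernels. mygq_shuffle
        # reorders the packed weights (done once at load time in the
        # MYQLinearMethod hunk below); mygq_gemm then runs the quantized
        # matmul on each forward pass.
        ops.mygq_shuffle(qweight, g_idx, bit)
        return ops.mygq_gemm(a, qweight, qzeros, qscales, g_idx,
                             use_exllama, bit)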
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _compat_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 // atomicAdd for half types, to support CC < 7.x
 __device__ __forceinline__ void atomicAdd_half(half* address, half val)
@@ -59,6 +59,6 @@ __device__ __forceinline__ void atomicAdd(half2* address, half2 val) { atomicAdd
 #endif
 #endif
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -11,7 +11,7 @@ Adapted from https://github.com/turboderp/exllamav2 and https://github.com/turbo
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 class MatrixView_half
 {
@@ -269,6 +269,6 @@ public:
 }
 };
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
@@ -81,7 +81,7 @@ __forceinline__ __device__ void dequant_2bit_16
     dq[7] = __hfma2(q7.as_half2, y64, z64);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -4,7 +4,7 @@
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // v9997775 55333111 u8886664 44222000  (u, v lsb)
@@ -135,7 +135,7 @@ __forceinline__ __device__ void dequant_3bit_32
     dq[15] = __hadd2(q15.as_half2, z1);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 // Permutation:
 //
 // 77775555 33331111 66664444 22220000
@@ -107,7 +107,7 @@ __forceinline__ __device__ void dequant_4bit_8_prep_zero
 }
-__forceinline__ __device__ void dequant_4bit_8_myq
+__forceinline__ __device__ void dequant_4bit_8_mygq
 (
     const uint32_t q_0,
     half2 (&dq)[4],
@@ -141,7 +141,7 @@ __forceinline__ __device__ void dequant_4bit_8_myq
     dq[3] = __hfma2(q3.as_half2, y1y16[1], z1z16[1]); // half2( q[6] - z, q[7] - z )
 }
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -8,7 +8,7 @@ Copied from https://github.com/turboderp/exllamav2
 #include "qdq_util.cuh"
 namespace vllm {
-namespace myq {
+namespace mygq {
 __forceinline__ __device__ void shuffle_8bit_4
 (
@@ -34,7 +34,7 @@ __forceinline__ __device__ void dequant_8bit_8
     for (int i = 0; i < 4; i++) dq[i] = __halves2half2(dqh[i * 2], dqh[i * 2 + 1]);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -6,7 +6,7 @@ Copied from https://github.com/turboderp/exllamav2
 #define _qdq_util_cuh
 namespace vllm {
-namespace myq {
+namespace mygq {
 union half2_uint32
 {
@@ -55,6 +55,6 @@ __forceinline__ __device__ int exb(const uint32_t q1, const uint32_t q0, const i
     return (int)(__funnelshift_rc(q0, q1, shift) & mask);
 }
-} // namespace myq
+} // namespace mygq
 } // namespace vllm
 #endif
@@ -339,7 +339,7 @@ vllm_extension_sources = [
     "csrc/layernorm_kernels.cu",
     "csrc/quantization/squeezellm/quant_cuda_kernel.cu",
     "csrc/quantization/gptq/q_gemm.cu",
-    "csrc/quantization/myq/q_gemm.cu",
+    "csrc/quantization/mygq/q_gemm.cu",
     "csrc/cuda_utils_kernels.cu",
     "csrc/moe_align_block_size_kernels.cu",
     "csrc/pybind.cpp",
...
@@ -155,7 +155,7 @@ class ModelConfig:
         self.tokenizer_mode = tokenizer_mode
     def _verify_quantization(self) -> None:
-        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","myq"]
+        supported_quantization = ["awq", "gptq", "squeezellm", "marlin","mygq"]
         rocm_not_supported_quantization = ["awq", "marlin"]
         if self.quantization is not None:
             self.quantization = self.quantization.lower()
...
@@ -208,7 +208,7 @@ class EngineArgs:
         parser.add_argument('--quantization',
                             '-q',
                             type=str,
-                            choices=['awq', 'gptq', 'squeezellm','myq', None],
+                            choices=['awq', 'gptq', 'squeezellm','mygq', None],
                             default=EngineArgs.quantization,
                             help='Method used to quantize the weights. If '
                             'None, we first check the `quantization_config` '
...
@@ -3,13 +3,13 @@ from typing import Type
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
 from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.myq import MYQConfig
+from vllm.model_executor.layers.quantization.mygq import MYQConfig
 from vllm.model_executor.layers.quantization.squeezellm import SqueezeLLMConfig
 from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 _QUANTIZATION_CONFIG_REGISTRY = {
     "awq": AWQConfig,
-    "myq": MYQConfig,
+    "mygq": MYQConfig,
     "gptq": GPTQConfig,
     "squeezellm": SqueezeLLMConfig,
     "marlin": MarlinConfig,
...
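The registry key has to match MYQConfig.get_name(), which the next hunk changes in lockstep, because upstream vLLM resolves the user-supplied quantization string through a lookup of this dict, roughly:

    from typing import Type

    def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
        # Sketch of the lookup that consumes _QUANTIZATION_CONFIG_REGISTRY
        # in upstream vLLM; after this commit "mygq" resolves to MYQConfig
        # and the old key "myq" raises.
        if quantization not in _QUANTIZATION_CONFIG_REGISTRY:
            raise ValueError(f"Invalid quantization method: {quantization}")
        return _QUANTIZATION_CONFIG_REGISTRY[quantization]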
@@ -41,7 +41,7 @@ class MYQConfig(QuantizationConfig):
     @classmethod
     def get_name(cls) -> str:
-        return "myq"
+        return "mygq"
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -201,9 +201,9 @@ class MYQLinearMethod(LinearMethodBase):
             else:
                 weights["g_idx"] = torch.empty((1, 1), device="meta")
             weights["exllama_state"] = ExllamaState.READY
-            ops.myq_shuffle(weights["qweight"], weights["g_idx"],
+            ops.mygq_shuffle(weights["qweight"], weights["g_idx"],
                             self.quant_config.weight_bits)
-        output = ops.myq_gemm(reshaped_x, weights["qweight"],
+        output = ops.mygq_gemm(reshaped_x, weights["qweight"],
                               weights["qzeros"], weights["scales"],
                               weights["g_idx"],
                               weights["exllama_state"] == ExllamaState.READY,
...
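Note the guard around ops.mygq_shuffle: the ExllamaState flag ensures the weight reordering runs only once, after which every forward pass goes straight to ops.mygq_gemm. The enum itself is not part of this diff; a sketch of what the hunk implies, reconstructed from usage:

    import enum

    class ExllamaState(enum.Enum):
        # Reconstructed from how the hunk above uses it; the real definition
        # lives elsewhere in mygq.py. UNINITIALIZED -> READY happens on the
        # first call, after which mygq_shuffle is never run again for these
        # weights.
        UNUSED = enum.auto()
        UNINITIALIZED = enum.auto()
        READY = enum.auto()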