Commit f5f79f5c authored by chenxl's avatar chenxl
Browse files

[ADD] support multi-gpu qlen>1 q5_k

parent f2938031
......@@ -15,3 +15,6 @@ node_modules
compile_commands.json
*.egg-info*
*dist/
ktransformers/server/local_store/
ktransformers/server_test1.db
*.patch
\ No newline at end of file
@echo off
REM clear build dirs
rmdir /S /Q ktransformers\ktransformers_ext\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\build
rmdir /S /Q ktransformers\ktransformers_ext\cuda\dist
rmdir /S /Q ktransformers\ktransformers_ext\out
del /F /Q ktransformers\ktransformers_ext\cuda\*.egg-info
echo Installing python dependencies from requirements.txt
pip install -r requirements-local_chat.txt
echo Installing ktransformers
set KTRANSFORMERS_FORCE_BUILD=TRUE
pip install . --no-build-isolation
echo Installation completed successfully
\ No newline at end of file
......@@ -11,5 +11,5 @@ echo "Installing python dependencies from requirements.txt"
pip install -r requirements-local_chat.txt
echo "Installing ktransformers"
pip install . --no-build-isolation
KTRANSFORMERS_FORCE_BUILD=TRUE pip install . --no-build-isolation
echo "Installation completed successfully"
\ No newline at end of file
......@@ -189,7 +189,13 @@ else()
message(STATUS "Unknown architecture")
endif()
find_package(CUDA REQUIRED)
# message(STATUS "CUDAToolkit_ROOT:${CUDAToolkit_ROOT}")
# find_package(FindCUDAToolkit REQUIRED)
# if(CUDAToolkit_FOUND)
# message(STATUS "Found CUDA cudart lib at:${CUDAToolkit_LIBRARY_DIR}")
# else()
# message(STATUS "Can't found CUDA lib")
# endif()
add_compile_options("$<$<COMPILE_LANGUAGE:CXX>:${ARCH_FLAGS}>")
add_compile_options("$<$<COMPILE_LANGUAGE:C>:${ARCH_FLAGS}>")
......@@ -198,7 +204,12 @@ add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/pybind11 ${CMAKE_
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party/llama.cpp ${CMAKE_CURRENT_BINARY_DIR}/third_party/llama.cpp)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
include_directories("${CUDA_INCLUDE_DIRS}")
if (WIN32)
include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX)
find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR}/cpu_backend SOURCE_DIR2)
......@@ -209,4 +220,8 @@ message(STATUS "ALL_SOURCES: ${ALL_SOURCES}")
pybind11_add_module(${PROJECT_NAME} MODULE ${ALL_SOURCES})
target_link_libraries(${PROJECT_NAME} PRIVATE llama)
target_link_libraries(${PROJECT_NAME} PRIVATE "/usr/local/cuda/lib64/libcudart.so")
\ No newline at end of file
if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
\ No newline at end of file
......@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:33:47
* @LastEditors : chenxl
* @LastEditTime : 2024-08-08 04:23:51
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#ifndef CPUINFER_TASKQUEUE_H
......@@ -17,6 +17,44 @@
#include <queue>
#include <thread>
#include <vector>
#ifdef _WIN32
#include <windows.h>
#endif
class custom_mutex {
private:
#ifdef _WIN32
HANDLE global_mutex;
#else
std::mutex global_mutex;
#endif
public:
custom_mutex()
{
#ifdef _WIN32
HANDLE global_mutex;
#endif
}
void lock()
{
#ifdef _WIN32
WaitForSingleObject(global_mutex, INFINITE);
#else
global_mutex.lock();
#endif
}
void unlock()
{
#ifdef _WIN32
ReleaseMutex(global_mutex);
#else
global_mutex.lock();
#endif
}
};
class TaskQueue {
public:
......@@ -32,7 +70,7 @@ class TaskQueue {
std::queue<std::function<void()>> tasks;
std::thread worker;
std::mutex mutex;
custom_mutex mutex;
std::atomic<bool> sync_flag;
std::atomic<bool> exit_flag;
};
......
......@@ -3,8 +3,8 @@
* @Author : Azure-Tang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 08:36:03
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-09 01:45:02
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
......@@ -23,6 +23,8 @@ PYBIND11_MODULE(KTransformersOps, m) {
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
......
......@@ -12,12 +12,15 @@ int test(){
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
PYBIND11_MODULE(cudaops, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("test", &test, "Function to test.");
......
#include <cuda_fp16.h>
__device__ float ggml_compute_fp16_to_fp32(uint16_t h) {
return __uint2float_rd(h);
}
static inline float ggml_compute_fp16_to_fp32(uint16_t h) {
uint16_t tmp;
memcpy(&tmp, &h, sizeof(ggml_fp16_t));
return (float)tmp;
}
// define the global table for fp16 to fp32 conversion
__device__ float ggml_table_f32_f16[1 << 16];
// CUDA Kernel to init the table
__global__ void init_fp16_to_fp32_table() {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto blk_id = idx; blk_id<(1 << 16); blk_id+=blockDim.x * gridDim.x){
ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id);
}
}
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code
// This version of the function is designed to be called from within a CUDA kernel
#if !defined(GGML_FP16_TO_FP32)
__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) {
return ggml_table_f32_f16[f];
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif
\ No newline at end of file
......@@ -3,8 +3,8 @@
* @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 11:58:50
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-09 07:57:06
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
......@@ -14,6 +14,7 @@
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -59,6 +60,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
......@@ -94,6 +124,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
// create gpu
auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
......@@ -128,6 +159,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
// data.numel%blk_size should be 0, else raise err
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
......@@ -147,6 +179,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
// data.numel%blk_size should be 0, else raise err
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
......@@ -162,3 +195,22 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
cudaDeviceSynchronize();
return output;
}
\ No newline at end of file
......@@ -3,8 +3,8 @@
* @Author : Azure-Tang
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 08:38:20
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-09 01:44:21
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
......@@ -15,4 +15,5 @@
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
......@@ -23,7 +23,7 @@
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include <c10/cuda/CUDAGuard.h>
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
......@@ -1703,28 +1703,63 @@ void marlin_mm_f16i4(const void* A, const void* B, void* C, void* s,
thread_m_blocks = exec_cfg.max_m_blocks;
}
// Define kernel configurations
#define undefined_error TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " + \
str(prob_n) + ", " + str(prob_k) + "]" + \
", has_act_order = " + str(has_act_order) + \
", num_groups = " + str(num_groups) + \
", group_size = " + str(group_size) + \
", thread_m_blocks = " + str(thread_m_blocks) + \
", thread_n_blocks = " + str(thread_n_blocks) + \
", thread_k_blocks = " + str(thread_k_blocks));
if (num_bits == 4 && num_threads == 256)
{
if (false) {
}
CALL_IF(4, 32, 2, 256)
CALL_IF(4, 16, 4, 256)
CALL_IF(4, 8, 8, 256)
else {
undefined_error
}
}
else if (num_bits == 4 && num_threads == 128)
{
if (false) {
}
CALL_IF(4, 8, 4, 128)
CALL_IF(4, 4, 8, 128)
else {
undefined_error
}
}
else if (num_bits == 8 && num_threads == 256)
{
if (false) {
}
CALL_IF(8, 32, 2, 256)
CALL_IF(8, 16, 4, 256)
CALL_IF(8, 8, 8, 256)
else {
undefined_error
}
}
else if (num_bits == 8 && num_threads == 128)
{
if (false) {
}
CALL_IF(8, 8, 4, 128)
CALL_IF(8, 4, 8, 128)
else {
TORCH_CHECK(false, "Unsupported shapes: MNK = [" + str(prob_m) + ", " +
str(prob_n) + ", " + str(prob_k) + "]" +
", has_act_order = " + str(has_act_order) +
", num_groups = " + str(num_groups) +
", group_size = " + str(group_size) +
", thread_m_blocks = " + str(thread_m_blocks) +
", thread_n_blocks = " + str(thread_n_blocks) +
", thread_k_blocks = " + str(thread_k_blocks));
undefined_error
}
}
else {
undefined_error
}
A_ptr += 16 * thread_m_blocks * (prob_k / 8) * par;
......@@ -1739,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
......@@ -1781,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
......
......@@ -2,17 +2,25 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# setup marlin gemm
setup(name='KTransformersOps',
setup(
name='KTransformersOps',
ext_modules=[
CUDAExtension('KTransformersOps', [
CUDAExtension(
'KTransformersOps', [
'custom_gguf/dequant.cu',
'binding.cpp',
'gptq_marlin/gptq_marlin.cu',
# 'gptq_marlin_repack.cu',
])
],
cmdclass={'build_ext': BuildExtension
})
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
]
},
)
],
cmdclass={'build_ext': BuildExtension}
)
\ No newline at end of file
......@@ -37,7 +37,7 @@ class LinearBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
}
static std::pair<intptr_t, intptr_t> interface(Linear& linear) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
Args* args = new Args{nullptr, &linear};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -55,7 +55,7 @@ class LinearBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -74,7 +74,7 @@ class MLPBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
}
static std::pair<intptr_t, intptr_t> interface(MLP& mlp) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
Args* args = new Args{nullptr, &mlp};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -92,7 +92,7 @@ class MLPBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -111,7 +111,7 @@ class MOEBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
}
static std::pair<intptr_t, intptr_t> interface(MOE& moe) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
Args* args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -132,7 +132,7 @@ class MOEBindings {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
......@@ -154,8 +154,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
}));
py::class_<Linear>(linear_module, "Linear")
.def(py::init<LinearConfig>())
.def("warm_up", &LinearBindings::WarmUpBindinds::interface)
.def("forward", &LinearBindings::ForwardBindings::interface);
.def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);
auto mlp_module = m.def_submodule("mlp");
py::class_<MLPConfig>(mlp_module, "MLPConfig")
......@@ -164,8 +164,8 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
}));
py::class_<MLP>(mlp_module, "MLP")
.def(py::init<MLPConfig>())
.def("warm_up", &MLPBindings::WarmUpBindinds::interface)
.def("forward", &MLPBindings::ForwardBindings::interface);
.def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);
auto moe_module = m.def_submodule("moe");
py::class_<MOEConfig>(moe_module, "MOEConfig")
......@@ -174,6 +174,6 @@ PYBIND11_MODULE(cpuinfer_ext, m) {
}));
py::class_<MOE>(moe_module, "MOE")
.def(py::init<MOEConfig>())
.def("warm_up", &MOEBindings::WarmUpBindinds::interface)
.def("forward", &MOEBindings::ForwardBindings::interface);
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
}
import math
import os
import time
from logging import getLogger
import torch
import torch.nn as nn
import transformers
from .quantizer import Quantizer
logger = getLogger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.quantizer = Quantizer()
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
self.inp1 = inp
self.out1 = out
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def fasterquant(
self,
blocksize=128,
percdamp=0.01,
group_size=-1,
actorder=False,
static_groups=False,
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()
tick = time.time()
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i : (i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if os.environ.get("DEBUG"):
self.layer.weight.data[:, :i2] = Q[:, :i2]
self.layer.weight.data[:, i2:] = W[:, i2:]
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
torch.cuda.synchronize()
logger.info(f"duration: {(time.time() - tick)}")
logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero, g_idx
def free(self):
if os.environ.get("DEBUG"):
self.inp1 = None
self.out1 = None
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
__all__ = ["GPTQ"]
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def get_pack_factor(num_bits: int):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(f"The params dtype must be float16 "
f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = torch.empty(
g_idx.shape,
dtype=torch.int32,
)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
from logging import getLogger
import torch
import torch.nn as nn
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
__all__ = ["Quantizer"]
import torch
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
......@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
class MarlinWorkspace:
def __init__(self, out_features, min_thread_n, max_parallel):
def __init__(self, out_features, min_thread_n, max_parallel, device):
assert (out_features % min_thread_n == 0), (
"out_features = {} is undivisible by min_thread_n = {}".format(
out_features, min_thread_n))
......@@ -229,4 +229,4 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda")
device=device)
......@@ -47,13 +47,13 @@ void Linear::forward_many(int qlen, const void* input, void* output, Backend* ba
int nth = config_.output_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id;
void* proj_ptr = proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
float* proj_output_ptr = proj_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
for (int i = 0; i < qlen; i++) {
float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;
void* output_ptr = output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
......@@ -69,5 +69,5 @@ void Linear::forward(int qlen, const void* input, void* output, Backend* backend
}
int forward_len = std::min(qlen, config_.group_max_len);
forward_many(forward_len, input, output, backend);
forward(qlen - forward_len, input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
\ No newline at end of file
......@@ -74,10 +74,10 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
int nth = config_.intermediate_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id;
void* gate_proj_ptr = gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = gate_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_proj_ptr = up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
float* up_output_ptr = up_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < qlen; i++) {
......@@ -86,7 +86,7 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;
void* down_input_ptr = down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
......@@ -97,13 +97,13 @@ void MLP::forward_many(int qlen, const void* input, void* output, Backend* backe
nth = config_.hidden_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id;
void* down_proj_ptr = down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = down_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
for (int i = 0; i < qlen; i++) {
float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;
void* output_ptr = output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
......@@ -119,5 +119,5 @@ void MLP::forward(int qlen, const void* input, void* output, Backend* backend) {
}
int forward_len = std::min(qlen, config_.group_max_len);
forward_many(forward_len, input, output, backend);
forward(qlen - forward_len, input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment