Unverified commit 77a34c28 authored by UnicornChan, committed by GitHub

Merge pull request #36 from kvcache-ai/develop-0.1.2

Release v0.1.2
parents 44f57270 395cd3e7
......@@ -12,14 +12,22 @@ int test(){
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
PYBIND11_MODULE(cudaops, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("test", &test, "Function to test.");
}
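A minimal host-side sketch (an assumption, not part of the commit) of driving the resulting module from Python; the blk_size values follow the GGML block layouts, e.g. 84 bytes encode one 256-element q2_K block:

import torch
import cudaops  # assumes the extension above was built and is importable under this name

raw = torch.empty(84 * 10, dtype=torch.int8)  # 10 q2_K blocks of placeholder data
out = cudaops.dequantize_q2_k(raw, 84, torch.device("cuda:0"))
print(out.shape)  # torch.Size([10, 256]), float32, on the requested device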
#include <cuda_fp16.h>
__device__ float ggml_compute_fp16_to_fp32(uint16_t h) {
    // Reinterpret the 16 bits as an IEEE-754 half and widen to float.
    // Converting the integer value of h (e.g. __uint2float_rd) would be wrong here.
    return __half2float(__ushort_as_half(h));
}
static inline float ggml_compute_fp16_to_fp32(uint16_t h) {
    // Host-side fallback: the same bit reinterpretation, via cuda_fp16's __half.
    __half tmp;
    memcpy(&tmp, &h, sizeof(tmp));
    return __half2float(tmp);
}
// define the global table for fp16 to fp32 conversion
__device__ float ggml_table_f32_f16[1 << 16];
// CUDA Kernel to init the table
__global__ void init_fp16_to_fp32_table() {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto blk_id = idx; blk_id<(1 << 16); blk_id+=blockDim.x * gridDim.x){
ggml_table_f32_f16[blk_id] = GGML_COMPUTE_FP16_TO_FP32(blk_id);
}
}
#define GGML_COMPUTE_FP16_TO_FP32(x) ggml_compute_fp16_to_fp32(x)
extern __device__ float ggml_table_f32_f16[1 << 16]; // Declare as __device__ if used within device code
// This version of the function is designed to be called from within a CUDA kernel
#if !defined(GGML_FP16_TO_FP32)
__device__ float ggml_lookup_fp16_to_fp32(uint16_t f) {
return ggml_table_f32_f16[f];
}
#define GGML_FP16_TO_FP32(x) ggml_lookup_fp16_to_fp32(x)
#endif
\ No newline at end of file
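For the lookup table to be meaningful, entry h must hold the value of the fp16 bit pattern h, not float(h). The same reinterpretation expressed in torch (a host-side check, assumption only):

import torch
bits = torch.tensor([0x3C00, 0x4000], dtype=torch.int16)  # binary16 patterns for 1.0 and 2.0
print(bits.view(torch.float16).float())  # tensor([1., 2.])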
......@@ -3,8 +3,8 @@
* @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 11:58:50
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 04:18:04
* Adapted from https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c
* Copyright (c) 2023-2024 The ggml authors
* Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
......@@ -14,6 +14,7 @@
#include <torch/extension.h>
#include <torch/torch.h>
#include <cstdint>
#include <c10/cuda/CUDAGuard.h>
__global__ void dequantize_q8_0_kernel(float* output, const float* scales, const int8_t* qs, int num_blocks, int blk_size) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
......@@ -35,6 +36,97 @@ __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t * __restrict_
}
}
__global__ void dequantize_q2_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 80)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 82)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 16);
int is = 0;
float dl, ml;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + (is++));
uint8_t sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l] >> shift) & 3)) - ml;
scales = (uint8_t*)(data + block_id * blk_size + (is++));
sc = *scales;
dl = d * (sc & 0xF); ml = min * (sc >> 4);
for (int l = 0; l < 16; ++l) *output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3)) - ml;
shift += 2;
}
q += 32;
}
}
}
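A hedged host-side reference (not part of the commit) that mirrors dequantize_q2_k_kernel for one 84-byte q2_K block; the offsets follow the GGML block_q2_K layout (scales[16], qs[64], d, dmin):

import torch

def dequant_q2_k_ref(blk: bytes) -> torch.Tensor:
    sc = torch.frombuffer(bytearray(blk[0:16]), dtype=torch.uint8)
    qs = torch.frombuffer(bytearray(blk[16:80]), dtype=torch.uint8)
    d = torch.frombuffer(bytearray(blk[80:82]), dtype=torch.float16).float()
    dmin = torch.frombuffer(bytearray(blk[82:84]), dtype=torch.float16).float()
    out, o, i = torch.empty(256), 0, 0
    for n in (0, 128):                    # two 128-element halves, 32 fresh qs bytes each
        q = qs[n // 4 : n // 4 + 32].int()
        for shift in (0, 2, 4, 6):        # four 2-bit planes per byte
            for lo in (0, 16):            # each run of 16 weights shares one scale byte
                s = int(sc[i]); i += 1
                dl, ml = d * (s & 0xF), dmin * (s >> 4)
                out[o:o + 16] = dl * ((q[lo:lo + 16] >> shift) & 3).float() - ml
                o += 16
    return out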
__global__ void dequantize_q3_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
uint32_t aux[4];
const int8_t * scales = (const int8_t*)aux;
const float d_all = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 108)));
const uint8_t * __restrict__ q = (uint8_t*)(data + block_id * blk_size + 32);
const uint8_t * __restrict__ hm = (uint8_t*)(data + block_id * blk_size + 0);
uint8_t m = 1;
uint8_t* block_scales = (uint8_t*)(data + block_id * blk_size + 96);
for (int i = 0; i < 3; i++) {
aux[i] = 0;
for (int j = 0; j < 4; j++) {
aux[i] |= ((uint32_t)block_scales[i * 4 + j]) << (j * 8);
}
}
uint32_t tmp = aux[2];
aux[2] = ((aux[0] >> 4) & kmask2) | (((tmp >> 4) & kmask1) << 4);
aux[3] = ((aux[1] >> 4) & kmask2) | (((tmp >> 6) & kmask1) << 4);
aux[0] = (aux[0] & kmask2) | (((tmp >> 0) & kmask1) << 4);
aux[1] = (aux[1] & kmask2) | (((tmp >> 2) & kmask1) << 4);
int is = 0;
float dl;
for (int n = 0; n < 256; n += 128) {
int shift = 0;
for (int j = 0; j < 4; ++j) {
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = dl * ((int8_t)((q[l+ 0] >> shift) & 3) - ((hm[l+ 0] & m) ? 0 : 4));
}
dl = d_all * (scales[is++] - 32);
for (int l = 0; l < 16; ++l) {
*output_blk++ = dl * ((int8_t)((q[l+16] >> shift) & 3) - ((hm[l+16] & m) ? 0 : 4));
}
shift += 2;
m <<= 1;
}
q += 32;
}
}
}
__global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
......@@ -59,6 +151,35 @@ __global__ void dequantize_q4_k_kernel(int8_t* data, float* output, int blk_size
}
}
__global__ void dequantize_q5_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks; block_id+= blockDim.x * gridDim.x){
float* __restrict__ output_blk = (float*)(output + block_id * 256);
const float d = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 0)));
const float min = __half2float(*(reinterpret_cast<half*>(data + block_id * blk_size + 2)));
const uint8_t * __restrict__ qh = (uint8_t*)(data + block_id * blk_size + 16);
const uint8_t * __restrict__ ql = (uint8_t*)(data + block_id * blk_size + 48);
int is = 0;
uint8_t sc, m;
uint8_t u1 = 1, u2 = 2;
uint8_t* scales = (uint8_t*)(data + block_id * blk_size + 4);
for (int j = 0; j < 256; j += 64) {
get_scale_min_k4(is + 0, scales, &sc, &m);
const float d1 = d * sc; const float m1 = min * m;
get_scale_min_k4(is + 1, scales, &sc, &m);
const float d2 = d * sc; const float m2 = min * m;
for (int l = 0; l < 32; ++l) *output_blk++ = d1 * ((ql[l] & 0xF) + (qh[l] & u1 ? 16 : 0)) - m1;
for (int l = 0; l < 32; ++l) *output_blk++ = d2 * ((ql[l] >> 4) + (qh[l] & u2 ? 16 : 0)) - m2;
ql += 32; is += 2;
u1 <<= 2; u2 <<= 2;
}
}
}
__global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size, int num_blocks) {
int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
for (auto block_id=global_idx; block_id<num_blocks;block_id+=blockDim.x * gridDim.x){
......@@ -94,6 +215,7 @@ __global__ void dequantize_q6_k_kernel(int8_t* data, float* output, int blk_size
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
// Create GPU-side buffers for the dequantization scales and quantized values
auto options_scales = torch::TensorOptions().dtype(torch::kFloat32).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto options_qs = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
......@@ -128,6 +250,7 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
// data.numel() % blk_size should be 0, else raise an error
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
......@@ -144,9 +267,28 @@ torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device de
return output;
}
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q5_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device) {
// data.numel() % blk_size should be 0, else raise an error
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
......@@ -162,3 +304,39 @@ torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device de
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q3_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
cudaDeviceSynchronize();
return output;
}
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device) {
int num_blocks = data.numel() / blk_size;
const at::cuda::OptionalCUDAGuard device_guard(device);
auto options = torch::TensorOptions().dtype(torch::kInt8).device(device).memory_format(torch::MemoryFormat::Contiguous);
auto data_gpu = torch::empty({data.numel()}, options);
data_gpu.copy_(data, false);
// Create output tensor
auto output = torch::zeros({num_blocks, 256}, torch::dtype(torch::kFloat32).device(device));
// Launch kernel
dequantize_q2_k_kernel<<< 512, 256 >>>(data_gpu.data_ptr<int8_t>(), output.data_ptr<float>(), blk_size, num_blocks);
cudaDeviceSynchronize();
return output;
}
\ No newline at end of file
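The wrappers divide numel by blk_size without checking the remainder (the comments above note it should raise). A hedged Python-side guard (hypothetical helper, not in the diff):

def dequantize_checked(fn, raw, blk_size, device):
    # fn is one of the cudaops.dequantize_* entry points bound above.
    if raw.numel() % blk_size != 0:
        raise ValueError(f"numel {raw.numel()} is not a multiple of blk_size {blk_size}")
    return fn(raw, blk_size, device)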
......@@ -3,8 +3,8 @@
* @Author : Azure-Tang
* @Date : 2024-07-22 09:27:55
* @Version : 1.0.0
* @LastEditors : Azure
* @LastEditTime : 2024-07-26 08:38:20
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:48:46
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
#pragma once
......@@ -15,4 +15,7 @@
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
\ No newline at end of file
......@@ -23,7 +23,7 @@
*/
#include "gptq_marlin.cuh"
#include "gptq_marlin_dtypes.cuh"
#include <c10/cuda/CUDAGuard.h>
#define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \
static_assert(std::is_same<scalar_t, half>::value || \
std::is_same<scalar_t, nv_bfloat16>::value, \
......@@ -1774,6 +1774,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
torch::Tensor& perm, torch::Tensor& workspace,
int64_t num_bits, int64_t size_m, int64_t size_n,
int64_t size_k, bool is_k_full) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
// Verify num_bits
TORCH_CHECK(num_bits == 4 || num_bits == 8,
"num_bits must be 4 or 8. Got = ", num_bits);
......@@ -1816,7 +1817,6 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight,
TORCH_CHECK(perm.is_contiguous(), "perm is not contiguous");
// Alloc buffers
const at::cuda::OptionalCUDAGuard device_guard(device_of(a));
auto options = torch::TensorOptions().dtype(a.dtype()).device(a.device());
torch::Tensor c = torch::empty({size_m, size_n}, options);
torch::Tensor a_tmp = torch::empty({size_m, size_k}, options);
......
......@@ -2,17 +2,25 @@
from setuptools import setup, Extension
from torch.utils import cpp_extension
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
# setup marlin gemm
setup(name='KTransformersOps',
ext_modules=[
CUDAExtension('KTransformersOps', [
setup(
name='KTransformersOps',
ext_modules=[
CUDAExtension(
'KTransformersOps', [
'custom_gguf/dequant.cu',
'binding.cpp',
'gptq_marlin/gptq_marlin.cu',
# 'gptq_marlin_repack.cu',
])
],
cmdclass={'build_ext': BuildExtension
})
# 'gptq_marlin_repack.cu',
],
extra_compile_args={
'cxx': ['-O3'],
'nvcc': [
'-O3',
'--use_fast_math',
'-Xcompiler', '-fPIC',
]
},
)
],
cmdclass={'build_ext': BuildExtension}
)
\ No newline at end of file
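A hedged build-and-import check (the exact command and import behavior are assumptions, not from the diff):

# python setup.py install        # or any equivalent pip / PEP 517 build
import KTransformersOps          # exposes the CUDA entry points compiled above
print(KTransformersOps.__file__)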
......@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:34:00
LastEditTime : 2024-08-06 10:36:59
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
......@@ -15,23 +15,23 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
with torch.inference_mode(mode=True):
input_size = 16384
output_size = 5120
stride = 32
proj_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
warm_up_iter = 1000
test_iter = 10000
input_size = 16384
output_size = 5120
stride = 32
group_max_len = 1024
proj_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
with torch.inference_mode(mode=True):
linears = []
projs = []
for _ in range(layer_num):
proj = torch.randn((output_size, input_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, proj.data_ptr(), proj_type, hidden_type)
config = cpuinfer_ext.linear.LinearConfig(input_size, output_size, stride, group_max_len, proj.data_ptr(), proj_type, hidden_type)
linear = cpuinfer_ext.linear.Linear(config)
projs.append(proj)
linears.append(linear)
......@@ -39,11 +39,17 @@ with torch.inference_mode(mode=True):
# validation
for i in range(validation_iter):
linear = linears[i % layer_num]
input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
input = torch.randn((qlen, input_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, output_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
CPUInfer.submit(
linear.forward(
qlen,
input.data_ptr(),
output.data_ptr()
)
)
CPUInfer.sync()
# print('cpuinfer output', output)
......@@ -54,30 +60,3 @@ with torch.inference_mode(mode=True):
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print('diff = ', diff)
assert(diff < 0.001)
# warm up
for i in range(warm_up_iter):
linear = linears[i % layer_num]
input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
CPUInfer.sync()
# test
total_time = 0
for i in range(test_iter):
linear = linears[i % layer_num]
input = torch.randn((1, input_size), dtype=torch.float16).contiguous()
output = torch.empty((1, output_size), dtype=torch.float16).contiguous()
input = input / 100
start = time.perf_counter()
CPUInfer.submit(linear.forward, input.data_ptr(), output.data_ptr())
CPUInfer.sync()
end = time.perf_counter()
total_time += end - start
print('Time: ', total_time)
print('Iteration: ', test_iter)
print('Time per iteration: ', total_time / test_iter)
print('Bandwidth: ', input_size * output_size * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
print("All tasks completed.")
\ No newline at end of file
......@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:34:03
LastEditTime : 2024-08-06 10:37:28
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
......@@ -15,20 +15,30 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
with torch.inference_mode(mode=True):
hidden_size = 5120
intermediate_size = 3072
stride = 32
gate_type = 1 # ggml_type::GGML_TYPE_F16
up_type = 1 # ggml_type::GGML_TYPE_F16
down_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
warm_up_iter = 1000
test_iter = 10000
hidden_size = 5120
intermediate_size = 3072
stride = 32
group_max_len = 1024
gate_type = 1 # ggml_type::GGML_TYPE_F16
up_type = 1 # ggml_type::GGML_TYPE_F16
down_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
def act_fn(x):
    # SiLU: x * sigmoid(x)
    return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
return ret
with torch.inference_mode(mode=True):
mlps = []
gate_projs = []
up_projs = []
......@@ -37,7 +47,7 @@ with torch.inference_mode(mode=True):
gate_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
up_proj = torch.randn((intermediate_size, hidden_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
down_proj = torch.randn((hidden_size, intermediate_size), dtype=torch.float16, device = "cuda").to("cpu").contiguous()
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
config = cpuinfer_ext.mlp.MLPConfig(hidden_size, intermediate_size, stride, group_max_len, gate_proj.data_ptr(), up_proj.data_ptr(), down_proj.data_ptr(), gate_type, up_type, down_type, hidden_type)
mlp = cpuinfer_ext.mlp.MLP(config)
gate_projs.append(gate_proj)
up_projs.append(up_proj)
......@@ -47,52 +57,26 @@ with torch.inference_mode(mode=True):
# validation
for i in range(validation_iter):
mlp = mlps[i % layer_num]
input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
CPUInfer.submit(
mlp.forward(
qlen,
input.data_ptr(),
output.data_ptr()
)
)
CPUInfer.sync()
# print('cpuinfer output', output)
def act_fn(x):
return x / (1.0 + torch.exp(-x))
gate_proj = gate_projs[i%layer_num]
up_proj = up_projs[i%layer_num]
down_proj = down_projs[i%layer_num]
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
t_output = torch.mm(intermediate, down_proj.t())
t_output = mlp_torch(input, gate_proj, up_proj, down_proj)
# print('torch output', t_output)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print('diff = ', diff)
assert(diff < 0.001)
# warm up
for i in range(warm_up_iter):
mlp = mlps[i % layer_num]
input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
CPUInfer.sync()
# test
total_time = 0
for i in range(test_iter):
mlp = mlps[i % layer_num]
input = torch.randn((1, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((1, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
start = time.time()
CPUInfer.submit(mlp.forward, input.data_ptr(), output.data_ptr())
CPUInfer.sync()
end = time.time()
total_time += end - start
print('Time: ', total_time)
print('Iteration: ', test_iter)
print('Time per iteration: ', total_time / test_iter)
print('Bandwidth: ', hidden_size * intermediate_size * 3 * 2 * test_iter / total_time / 1024 / 1024 / 1024, 'GB/s')
print("All tasks completed.")
\ No newline at end of file
......@@ -6,7 +6,7 @@ Author : chenht2022
Date : 2024-07-25 10:32:05
Version : 1.0.0
LastEditors : chenht2022
LastEditTime : 2024-07-25 10:34:06
LastEditTime : 2024-08-06 10:38:05
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import os, sys
......@@ -15,25 +15,64 @@ sys.path.append(os.path.dirname(__file__) + '/../build')
import cpuinfer_ext
import torch
with torch.inference_mode(mode=True):
expert_num = 10
hidden_size = 5120
intermediate_size = 1536
stride = 32
group_min_len = 10
group_max_len = 1024
gate_type = 1 # ggml_type::GGML_TYPE_F16
up_type = 1 # ggml_type::GGML_TYPE_F16
down_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
n_routed_experts = 6
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
warm_up_iter = 1000
test_iter = 10000
expert_num = 160
hidden_size = 5120
intermediate_size = 1536
stride = 32
group_min_len = 10
group_max_len = 1024
gate_type = 1 # ggml_type::GGML_TYPE_F16
up_type = 1 # ggml_type::GGML_TYPE_F16
down_type = 1 # ggml_type::GGML_TYPE_F16
hidden_type = 1 # ggml_type::GGML_TYPE_F16
n_routed_experts = 6
qlen = 30
layer_num = 10
CPUInfer = cpuinfer_ext.CPUInfer(48)
validation_iter = 100
def act_fn(x):
    # SiLU: x * sigmoid(x)
    return x / (1.0 + torch.exp(-x))
def mlp_torch(input, gate_proj, up_proj, down_proj):
gate_buf = torch.mm(input, gate_proj.t())
up_buf = torch.mm(input, up_proj.t())
intermediate = act_fn(gate_buf) * up_buf
ret = torch.mm(intermediate, down_proj.t())
return ret
def moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj):
    # Group tokens by expert, run each routed expert's MLP on its token batch,
    # then scatter the results back and combine them with the routing weights.
cnts = expert_ids.new_zeros((expert_ids.shape[0], expert_num))
cnts.scatter_(1, expert_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = expert_ids.view(-1).argsort()
sorted_tokens = input[idxs // expert_ids.shape[1]]
outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
expert_out = mlp_torch(tokens_for_this_expert, gate_proj[i], up_proj[i], down_proj[i])
outputs.append(expert_out)
start_idx = end_idx
outs = torch.cat(outputs, dim=0) if len(outputs) else sorted_tokens.new_empty(0)
new_x = torch.empty_like(outs)
new_x[idxs] = outs
t_output = (
new_x.view(*expert_ids.shape, -1)
.type(weights.dtype)
.mul_(weights.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)
return t_output
with torch.inference_mode(mode=True):
moes = []
gate_projs = []
up_projs = []
......@@ -51,63 +90,32 @@ with torch.inference_mode(mode=True):
# validation
for i in range(validation_iter):
moe = moes[i % layer_num]
expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
expert_ids = torch.stack([torch.randperm(expert_num)[:n_routed_experts] for _ in range(qlen)]).contiguous()
weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
input = torch.randn((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, 1, hidden_size), dtype=torch.float16).contiguous()
input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
moe = moes[i % layer_num]
CPUInfer.submit(
moe.forward(
qlen,
n_routed_experts,
expert_ids.data_ptr(),
weights.data_ptr(),
input.data_ptr(),
output.data_ptr()
)
)
CPUInfer.sync()
# print('cpuinfer output', output)
def act_fn(x):
return x / (1.0 + torch.exp(-x))
t_output = torch.zeros((qlen, 1, hidden_size), dtype=torch.float32).contiguous()
gate_proj = gate_projs[i%layer_num]
up_proj = up_projs[i%layer_num]
down_proj = down_projs[i%layer_num]
for token_idx in range(qlen):
for i, expert_id in enumerate(expert_ids[token_idx]):
gate_buf = torch.mm(input[token_idx], gate_proj[expert_id].t())
up_buf = torch.mm(input[token_idx], up_proj[expert_id].t())
intermediate = act_fn(gate_buf) * up_buf
expert_output = torch.mm(intermediate, down_proj[expert_id].t())
t_output[token_idx] += weights[token_idx][i] * expert_output
t_output = moe_torch(input, expert_ids, weights, gate_proj, up_proj, down_proj)
# print('torch output', t_output)
diff = torch.mean(torch.abs(output - t_output)) / torch.mean(torch.abs(t_output))
print('diff = ', diff)
assert(diff < 0.001)
# warm up
for i in range(warm_up_iter):
moe = moes[i % layer_num]
expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
CPUInfer.sync()
# test
total_time = 0
for i in range(test_iter):
moe = moes[i % layer_num]
expert_ids = torch.randint(0, expert_num, (qlen, n_routed_experts), dtype=torch.int64).contiguous()
weights = torch.rand((qlen, n_routed_experts), dtype=torch.float32).contiguous()
input = torch.randn((qlen, hidden_size), dtype=torch.float16).contiguous()
output = torch.empty((qlen, hidden_size), dtype=torch.float16).contiguous()
input = input / 100
start = time.perf_counter()
CPUInfer.submit(moe.forward, qlen, n_routed_experts, expert_ids.data_ptr(), weights.data_ptr(), input.data_ptr(), output.data_ptr())
CPUInfer.sync()
end = time.perf_counter()
total_time += end - start
print('Time: ', total_time)
print('Iteration: ', test_iter)
print('Time per iteration: ', total_time / test_iter)
print('Bandwidth: ', hidden_size * intermediate_size * 3 * n_routed_experts * 2 * test_iter / total_time / 1000 / 1000 / 1000, 'GB/s')
print("All tasks completed.")
\ No newline at end of file
......@@ -3,8 +3,8 @@
* @Author : chenht2022
* @Date : 2024-07-22 02:03:22
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:23
* @LastEditors : chenht2022
* @LastEditTime : 2024-08-07 10:39:37
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
// Python bindings
......@@ -12,7 +12,6 @@
#include <iostream>
#include <memory>
#include "cpu_backend/cpuinfer.h"
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include "llamafile/flags.h"
#include "operators/llamafile/linear.h"
......@@ -26,239 +25,155 @@
namespace py = pybind11;
using namespace pybind11::literals;
// Binding functions for the Linear class
class LinearBindings {
public:
static void bind_forward(CPUInfer& cpuinfer, Linear* linear, py::args args, py::kwargs kwargs) {
auto input = args[0].cast<intptr_t>();
auto output = args[1].cast<intptr_t>();
cpuinfer.submit(&Linear::forward, linear,
(const void*)input, (void*)output);
}
static void bind_warm_up(CPUInfer& cpuinfer, Linear* linear, py::args args, py::kwargs kwargs) {
cpuinfer.submit(&Linear::warm_up, linear);
}
static void bind_functions(CPUInfer& cpuinfer, py::object func, py::args args, py::kwargs kwargs) {
auto linear = func.attr("__self__").cast<Linear*>();
std::string func_name = py::str(func.attr("__func__").attr("__name__"));
if (func_name == "forward") {
bind_forward(cpuinfer, linear, args, kwargs);
} else if (func_name == "warm_up") {
bind_warm_up(cpuinfer, linear, args, kwargs);
} else {
throw py::value_error("Unsupported function: " +
std::string(func_name));
class WarmUpBindings {
public:
struct Args {
CPUInfer* cpuinfer;
Linear* linear;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&Linear::warm_up, args_->linear);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear) {
Args* args = new Args{nullptr, &linear};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer* cpuinfer;
Linear* linear;
int qlen;
const void* input;
void* output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&Linear::forward, args_->linear, args_->qlen, args_->input, args_->output);
}
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(Linear& linear, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &linear, qlen, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
// Binding functions for the MLP class
class MLPBindings {
public:
static void bind_forward(CPUInfer& cpuinfer, MLP* mlp, py::args args, py::kwargs kwargs) {
auto input = args[0].cast<intptr_t>();
auto output = args[1].cast<intptr_t>();
cpuinfer.submit(&MLP::forward, mlp,
(const void*)input, (void*)output);
}
static void bind_warm_up(CPUInfer& cpuinfer, MLP* mlp, py::args args, py::kwargs kwargs) {
cpuinfer.submit(&MLP::warm_up, mlp);
}
static void bind_functions(CPUInfer& cpuinfer, py::object func, py::args args, py::kwargs kwargs) {
auto mlp = func.attr("__self__").cast<MLP*>();
std::string func_name = py::str(func.attr("__func__").attr("__name__"));
if (func_name == "forward") {
bind_forward(cpuinfer, mlp, args, kwargs);
} else if (func_name == "warm_up") {
bind_warm_up(cpuinfer, mlp, args, kwargs);
} else {
throw py::value_error("Unsupported function: " +
std::string(func_name));
class WarmUpBindings {
public:
struct Args {
CPUInfer* cpuinfer;
MLP* mlp;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MLP::warm_up, args_->mlp);
}
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp) {
Args* args = new Args{nullptr, &mlp};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer* cpuinfer;
MLP* mlp;
int qlen;
const void* input;
void* output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MLP::forward, args_->mlp, args_->qlen, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MLP& mlp, int qlen, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &mlp, qlen, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
// Binding functions for the MOE class
class MOEBindings {
public:
static void bind_forward(CPUInfer& cpuinfer, MOE* moe, py::args args, py::kwargs kwargs) {
int qlen = args[0].cast<int>();
int k = args[1].cast<int>();
auto expert_ids = args[2].cast<intptr_t>();
auto weights = args[3].cast<intptr_t>();
auto input = args[4].cast<intptr_t>();
auto output = args[5].cast<intptr_t>();
cpuinfer.submit(&MOE::forward, moe,
qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output);
}
static void bind_warm_up(CPUInfer& cpuinfer, MOE* moe, py::args args, py::kwargs kwargs) {
cpuinfer.submit(&MOE::warm_up, moe);
}
static void bind_functions(CPUInfer& cpuinfer, py::object func, py::args args, py::kwargs kwargs) {
auto moe = func.attr("__self__").cast<MOE*>();
std::string func_name = py::str(func.attr("__func__").attr("__name__"));
if (func_name == "forward") {
bind_forward(cpuinfer, moe, args, kwargs);
} else if (func_name == "warm_up") {
bind_warm_up(cpuinfer, moe, args, kwargs);
} else {
throw py::value_error("Unsupported function: " +
std::string(func_name));
class WarmUpBindings {
public:
struct Args {
CPUInfer* cpuinfer;
MOE* moe;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MOE::warm_up, args_->moe);
}
}
};
struct MOEForwardArgs {
CPUInfer* cpuinfer;
MOE* moe;
int qlen;
int k;
uint64_t* expert_ids;
float* weights;
void* input;
void* output;
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe) {
Args* args = new Args{nullptr, &moe};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
class ForwardBindings {
public:
struct Args {
CPUInfer* cpuinfer;
MOE* moe;
int qlen;
int k;
const uint64_t* expert_ids;
const float* weights;
const void* input;
void* output;
};
static void inner(void* args) {
Args* args_ = (Args*)args;
args_->cpuinfer->enqueue(&MOE::forward, args_->moe, args_->qlen, args_->k, args_->expert_ids, args_->weights, args_->input, args_->output);
}
static std::pair<intptr_t, intptr_t> cpuinfer_interface(MOE& moe, int qlen, int k, intptr_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
Args* args = new Args{nullptr, &moe, qlen, k, (const uint64_t*)expert_ids, (const float*)weights, (const void*)input, (void*)output};
return std::make_pair((intptr_t)&inner, (intptr_t)args);
}
};
};
void submit_moe_forward_with_host_args_ptr(void* host_args_ptr) {
MOEForwardArgs* host_args = (MOEForwardArgs*)host_args_ptr;
host_args->cpuinfer->submit(&MOE::forward, host_args->moe,
host_args->qlen, host_args->k, host_args->expert_ids, host_args->weights, host_args->input, host_args->output);
}
void cpuinfer_sync(void* host_args_ptr) {
CPUInfer* cpuinfer = (CPUInfer*)host_args_ptr;
cpuinfer->sync();
}
PYBIND11_MODULE(cpuinfer_ext, m) {
auto linear_module = m.def_submodule("linear");
py::class_<CPUInfer>(m, "CPUInfer")
.def(py::init<int>())
.def("submit", &CPUInfer::submit)
.def("submit_with_cuda_stream", &CPUInfer::submit_with_cuda_stream)
.def("sync", &CPUInfer::sync)
.def("sync_with_cuda_stream", &CPUInfer::sync_with_cuda_stream);
auto linear_module = m.def_submodule("linear");
py::class_<LinearConfig>(linear_module, "LinearConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, intptr_t proj, int proj_type, int hidden_type) {
return LinearConfig(hidden_size, intermediate_size, stride, (void*)proj, (ggml_type)proj_type, (ggml_type)hidden_type);
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t proj, int proj_type, int hidden_type) {
return LinearConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)proj, (ggml_type)proj_type, (ggml_type)hidden_type);
}));
py::class_<Linear>(linear_module, "Linear")
.def(py::init<LinearConfig>())
.def("warm_up", [](Linear& linear) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
})
.def("forward", [](Linear& linear, intptr_t input, intptr_t output) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
});
.def("warm_up", &LinearBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &LinearBindings::ForwardBindings::cpuinfer_interface);
auto mlp_module = m.def_submodule("mlp");
py::class_<MLPConfig>(mlp_module, "MLPConfig")
.def(py::init([](int hidden_size, int intermediate_size, int stride, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
return MLPConfig(hidden_size, intermediate_size, stride, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
.def(py::init([](int hidden_size, int intermediate_size, int stride, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
return MLPConfig(hidden_size, intermediate_size, stride, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
}));
py::class_<MLP>(mlp_module, "MLP")
.def(py::init<MLPConfig>())
.def("warm_up", [](MLP& mlp) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
})
.def("forward", [](MLP& mlp, intptr_t input, intptr_t output) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
});
.def("warm_up", &MLPBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MLPBindings::ForwardBindings::cpuinfer_interface);
auto moe_module = m.def_submodule("moe");
py::class_<MOEConfig>(moe_module, "MOEConfig")
.def(py::init([](int expert_num, int routed_expert_num, int hidden_size, int intermediate_size, int stride, int group_min_len, int group_max_len, intptr_t gate_proj, intptr_t up_proj, intptr_t down_proj, int gate_type, int up_type, int down_type, int hidden_type) {
return MOEConfig(expert_num, routed_expert_num, hidden_size, intermediate_size, stride, group_min_len, group_max_len, (void*)gate_proj, (void*)up_proj, (void*)down_proj, (ggml_type)gate_type, (ggml_type)up_type, (ggml_type)down_type, (ggml_type)hidden_type);
}));
py::class_<MOE>(moe_module, "MOE")
.def(py::init<MOEConfig>())
.def("warm_up", [](MOE& moe) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
})
.def("forward", [](MOE& moe, int k, uint64_t expert_ids, intptr_t weights, intptr_t input, intptr_t output) {
throw std::runtime_error("!!! Doing nothing, please use CPUInfer.submit to call it!!!\n");
});
py::class_<CPUInfer>(m, "CPUInfer")
.def(py::init<int>())
.def("submit",
[linear_module, mlp_module, moe_module](CPUInfer& cpuinfer, py::object func, py::args args, py::kwargs kwargs) {
if (py::hasattr(func, "__self__") &&
py::hasattr(func, "__func__")) {
std::string class_name = py::str(func.attr("__self__")
.attr("__class__")
.attr("__name__"));
if (class_name == "Linear") {
LinearBindings::bind_functions(cpuinfer, func,
args, kwargs);
} else if (class_name == "MLP") {
MLPBindings::bind_functions(cpuinfer, func,
args, kwargs);
} else if (class_name == "MOE") {
MOEBindings::bind_functions(cpuinfer, func,
args, kwargs);
} else {
// handle other classes
throw py::type_error("Unsupported class type: " +
class_name);
}
} else {
// handle cases where func does not have __self__ or
// __func__
throw py::type_error(
"Invalid function object: missing "
"__self__ or __func__ attribute.");
}
})
.def("submit_with_cuda_stream",
[linear_module, mlp_module, moe_module](CPUInfer& cpuinfer, intptr_t user_cuda_stream, py::object func, py::args args, py::kwargs kwargs) {
if (py::hasattr(func, "__self__") &&
py::hasattr(func, "__func__")) {
std::string class_name = py::str(func.attr("__self__")
.attr("__class__")
.attr("__name__"));
if (class_name == "MOE") {
std::string func_name = py::str(func.attr("__func__").attr("__name__"));
if (func_name == "forward") {
auto moe = func.attr("__self__").cast<MOE*>();
int qlen = args[0].cast<int>();
int k = args[1].cast<int>();
auto expert_ids = args[2].cast<intptr_t>();
auto weights = args[3].cast<intptr_t>();
auto input = args[4].cast<intptr_t>();
auto output = args[5].cast<intptr_t>();
MOEForwardArgs* moe_forward_args = new MOEForwardArgs{&cpuinfer, moe, qlen, k, (uint64_t*)expert_ids, (float*)weights, (void*)input, (void*)output};
// submit_moe_forward_with_host_args_ptr(moe_forward_args);
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)submit_moe_forward_with_host_args_ptr, moe_forward_args);
} else {
throw py::value_error("Unsupported function: " +
std::string(func_name));
}
} else {
// handle other classes
throw py::type_error("Unsupported class type: " +
class_name);
}
} else {
// handle cases where func does not have __self__ or
// __func__
throw py::type_error(
"Invalid function object: missing "
"__self__ or __func__ attribute.");
}
})
.def("sync_with_cuda_stream", [](CPUInfer& cpuinfer, intptr_t user_cuda_stream) {
// cpuinfer_sync((void*)(&cpuinfer));
cudaLaunchHostFunc((cudaStream_t)user_cuda_stream, (cudaHostFn_t)cpuinfer_sync, (void*)(&cpuinfer));
})
.def("sync", &CPUInfer::sync);
.def("warm_up", &MOEBindings::WarmUpBindinds::cpuinfer_interface)
.def("forward", &MOEBindings::ForwardBindings::cpuinfer_interface);
}
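The new calling convention, as exercised by the updated test scripts: a bound forward/warm_up no longer runs any work itself but returns a (function-pointer, args-pointer) pair, which CPUInfer.submit (or submit_with_cuda_stream) enqueues. A sketch using the names from the linear test above:

cpu_infer = cpuinfer_ext.CPUInfer(48)
task = linear.forward(qlen, input.data_ptr(), output.data_ptr())  # builds the pair, no work yet
cpu_infer.submit(task)
cpu_infer.sync()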
import math
import os
import time
from logging import getLogger
import torch
import torch.nn as nn
import transformers
from .quantizer import Quantizer
logger = getLogger(__name__)
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False
class GPTQ:
def __init__(self, layer):
self.layer = layer
self.dev = self.layer.weight.device
W = layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.pytorch_utils.Conv1D):
W = W.t()
self.rows = W.shape[0]
self.columns = W.shape[1]
self.H = torch.zeros((self.columns, self.columns), device=self.dev)
self.nsamples = 0
self.quantizer = Quantizer()
def add_batch(self, inp, out):
if os.environ.get("DEBUG"):
self.inp1 = inp
self.out1 = out
if len(inp.shape) == 2:
inp = inp.unsqueeze(0)
tmp = inp.shape[0]
if isinstance(self.layer, nn.Linear) or isinstance(self.layer, transformers.Conv1D):
if len(inp.shape) == 3:
inp = inp.reshape((-1, inp.shape[-1]))
inp = inp.t()
if isinstance(self.layer, nn.Conv2d):
unfold = nn.Unfold(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.stride,
)
inp = unfold(inp)
inp = inp.permute([1, 0, 2])
inp = inp.flatten(1)
self.H *= self.nsamples / (self.nsamples + tmp)
self.nsamples += tmp
# inp = inp.float()
inp = math.sqrt(2 / self.nsamples) * inp.float()
# self.H += 2 / self.nsamples * inp.matmul(inp.t())
self.H += inp.matmul(inp.t())
def fasterquant(
self,
blocksize=128,
percdamp=0.01,
group_size=-1,
actorder=False,
static_groups=False,
):
W = self.layer.weight.data.clone()
if isinstance(self.layer, nn.Conv2d):
W = W.flatten(1)
if isinstance(self.layer, transformers.Conv1D):
W = W.t()
W = W.float()
tick = time.time()
if not self.quantizer.ready():
self.quantizer.find_params(W, weight=True)
H = self.H
del self.H
dead = torch.diag(H) == 0
H[dead, dead] = 1
W[:, dead] = 0
g_idx = []
scale = []
zero = []
now_idx = 1
if static_groups:
import copy
groups = []
for i in range(0, self.columns, group_size):
quantizer = copy.deepcopy(self.quantizer)
quantizer.find_params(W[:, i : (i + group_size)], weight=True)
scale.append(quantizer.scale)
zero.append(quantizer.zero)
groups.append(quantizer)
if actorder:
perm = torch.argsort(torch.diag(H), descending=True)
W = W[:, perm]
H = H[perm][:, perm]
invperm = torch.argsort(perm)
Losses = torch.zeros_like(W)
Q = torch.zeros_like(W)
damp = percdamp * torch.mean(torch.diag(H))
diag = torch.arange(self.columns, device=self.dev)
H[diag, diag] += damp
H = torch.linalg.cholesky(H)
H = torch.cholesky_inverse(H)
H = torch.linalg.cholesky(H, upper=True)
Hinv = H
for i1 in range(0, self.columns, blocksize):
i2 = min(i1 + blocksize, self.columns)
count = i2 - i1
W1 = W[:, i1:i2].clone()
Q1 = torch.zeros_like(W1)
Err1 = torch.zeros_like(W1)
Losses1 = torch.zeros_like(W1)
Hinv1 = Hinv[i1:i2, i1:i2]
for i in range(count):
w = W1[:, i]
d = Hinv1[i, i]
if group_size != -1:
if not static_groups:
if (i1 + i) % group_size == 0:
self.quantizer.find_params(W[:, (i1 + i) : (i1 + i + group_size)], weight=True)
if ((i1 + i) // group_size) - now_idx == -1:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
now_idx += 1
else:
idx = i1 + i
if actorder:
idx = perm[idx]
self.quantizer = groups[idx // group_size]
q = self.quantizer.quantize(w.unsqueeze(1)).flatten()
Q1[:, i] = q
Losses1[:, i] = (w - q) ** 2 / d**2
err1 = (w - q) / d
W1[:, i:] -= err1.unsqueeze(1).matmul(Hinv1[i, i:].unsqueeze(0))
Err1[:, i] = err1
Q[:, i1:i2] = Q1
Losses[:, i1:i2] = Losses1 / 2
W[:, i2:] -= Err1.matmul(Hinv[i1:i2, i2:])
if os.environ.get("DEBUG"):
self.layer.weight.data[:, :i2] = Q[:, :i2]
self.layer.weight.data[:, i2:] = W[:, i2:]
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
logger.debug(torch.sum(Losses))
torch.cuda.synchronize()
logger.info(f"duration: {(time.time() - tick)}")
logger.info(f"avg loss: {torch.sum(Losses).item() / self.nsamples}")
group_size = group_size if group_size != -1 else self.columns
if static_groups and actorder:
g_idx = [perm[i] // group_size for i in range(self.columns)]
else:
g_idx = [i // group_size for i in range(self.columns)]
g_idx = torch.tensor(g_idx, dtype=torch.int32, device=Q.device)
if actorder:
Q = Q[:, invperm]
g_idx = g_idx[invperm]
if isinstance(self.layer, transformers.Conv1D):
Q = Q.t()
self.layer.weight.data = Q.reshape(self.layer.weight.shape).type_as(self.layer.weight.data)
if os.environ.get("DEBUG"):
logger.debug(torch.sum((self.layer(self.inp1) - self.out1) ** 2))
if scale == []:
scale.append(self.quantizer.scale)
zero.append(self.quantizer.zero)
scale = torch.cat(scale, dim=1)
zero = torch.cat(zero, dim=1)
return scale, zero, g_idx
def free(self):
if os.environ.get("DEBUG"):
self.inp1 = None
self.out1 = None
self.H = None
self.Losses = None
self.Trace = None
torch.cuda.empty_cache()
__all__ = ["GPTQ"]
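A hedged calibration sketch (the layer size and the calibration iterable are hypothetical; the flow follows the methods above):

layer = nn.Linear(768, 768)
gptq = GPTQ(layer)
gptq.quantizer.configure(bits=4, perchannel=True, sym=False)
for inp in calibration_inputs:               # hypothetical iterable of (batch, 768) tensors
    gptq.add_batch(inp, layer(inp))          # accumulates the Hessian H
scale, zero, g_idx = gptq.fasterquant(blocksize=128, group_size=128, actorder=True)
gptq.free()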
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
import torch
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
logger = init_logger(__name__)
GPTQ_MARLIN_TILE = 16
GPTQ_MARLIN_MIN_THREAD_N = 64
GPTQ_MARLIN_MIN_THREAD_K = 128
GPTQ_MARLIN_MAX_PARALLEL = 16
GPTQ_MARLIN_SUPPORTED_NUM_BITS = [4, 8]
GPTQ_MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
GPTQ_MARLIN_SUPPORTED_SYM = [True]
# Permutations for Marlin scale shuffling
def get_scale_perms(num_bits: int):
scale_perm: List[int] = []
for i in range(8):
scale_perm.extend([i + 8 * j for j in range(8)])
scale_perm_single: List[int] = []
for i in range(4):
scale_perm_single.extend(
[2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
return scale_perm, scale_perm_single
def get_pack_factor(num_bits: int):
assert (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
), f"Unsupported num_bits = {num_bits}"
return 32 // num_bits
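For instance, 4-bit weights pack eight values per 32-bit word and 8-bit weights pack four:

assert get_pack_factor(4) == 8 and get_pack_factor(8) == 4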
def marlin_permute_scales(s: torch.Tensor, size_k: int, size_n: int,
group_size: int, num_bits: int):
scale_perm, scale_perm_single = get_scale_perms(num_bits)
if group_size < size_k and group_size != -1:
s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
else:
s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
s = s.reshape((-1, size_n)).contiguous()
return s
class GPTQMarlinConfig(QuantizationConfig):
"""Config class for GPTQ Marlin"""
def __init__(self, weight_bits: int, group_size: int, desc_act: bool,
is_sym: bool) -> None:
if desc_act and group_size == -1:
# In this case, act_order == True is the same as act_order == False
# (since we have only one group per output channel)
desc_act = False
self.weight_bits = weight_bits
self.group_size = group_size
self.desc_act = desc_act
self.is_sym = is_sym
# Verify
if self.weight_bits not in GPTQ_MARLIN_SUPPORTED_NUM_BITS:
raise ValueError(
f"Marlin does not support weight_bits = {self.weight_bits}. "
f"Only weight_bits = {GPTQ_MARLIN_SUPPORTED_NUM_BITS} "
"are supported.")
if self.group_size not in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES:
raise ValueError(
f"Marlin does not support group_size = {self.group_size}. "
f"Only group_sizes = {GPTQ_MARLIN_SUPPORTED_GROUP_SIZES} "
"are supported.")
if self.is_sym not in GPTQ_MARLIN_SUPPORTED_SYM:
raise ValueError(
f"Marlin does not support is_sym = {self.is_sym}. "
f"Only sym = {GPTQ_MARLIN_SUPPORTED_SYM} are supported.")
# Init
self.pack_factor = get_pack_factor(weight_bits)
self.tile_size = GPTQ_MARLIN_TILE
self.min_thread_n = GPTQ_MARLIN_MIN_THREAD_N
self.min_thread_k = GPTQ_MARLIN_MIN_THREAD_K
self.max_parallel = GPTQ_MARLIN_MAX_PARALLEL
def __repr__(self) -> str:
return (f"GPTQMarlinConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, "
f"desc_act={self.desc_act})")
@classmethod
def get_name(cls) -> str:
return "gptq_marlin"
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def get_config_filenames(cls) -> List[str]:
return ["quantize_config.json"]
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "GPTQMarlinConfig":
weight_bits = cls.get_from_keys(config, ["bits"])
group_size = cls.get_from_keys(config, ["group_size"])
desc_act = cls.get_from_keys(config, ["desc_act"])
is_sym = cls.get_from_keys(config, ["sym"])
return cls(weight_bits, group_size, desc_act, is_sym)
@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
can_convert = cls.is_marlin_compatible(hf_quant_cfg)
is_valid_user_quant = (user_quant is None or user_quant == "marlin")
if can_convert and is_valid_user_quant:
msg = ("The model is convertible to {} during runtime."
" Using {} kernel.".format(cls.get_name(), cls.get_name()))
logger.info(msg)
return cls.get_name()
if can_convert and user_quant == "gptq":
logger.info("Detected that the model can run with gptq_marlin"
", however you specified quantization=gptq explicitly,"
" so forcing gptq. Use quantization=gptq_marlin for"
" faster inference")
return None
def get_quant_method(
self,
layer: torch.nn.Module) -> Optional["GPTQMarlinLinearMethod"]:
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
return None
def get_scaled_act_names(self) -> List[str]:
return []
@classmethod
def is_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
# If we cannot find the info needed in the config, cannot convert.
if (num_bits is None or group_size is None or sym is None
or desc_act is None):
return False
# If the capability of the device is too low, cannot convert.
major, minor = torch.cuda.get_device_capability()
device_capability = major * 10 + minor
if device_capability < cls.get_min_capability():
return False
# Otherwise, can convert if model satisfies marlin constraints.
return (num_bits in GPTQ_MARLIN_SUPPORTED_NUM_BITS
and group_size in GPTQ_MARLIN_SUPPORTED_GROUP_SIZES
and sym in GPTQ_MARLIN_SUPPORTED_SYM)
class GPTQMarlinState(Enum):
REPACK = enum.auto()
READY = enum.auto()
class GPTQMarlinLinearMethod(LinearMethodBase):
"""Linear method for GPTQ Marlin.
Args:
quant_config: The GPTQ Marlin quantization config.
"""
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
# Validate dtype
if params_dtype not in [torch.float16, torch.bfloat16]:
raise ValueError(f"The params dtype must be float16 "
f"or bfloat16, but got {params_dtype}")
# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_thread_n != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f" min_thread_n = {self.quant_config.min_thread_n}.")
# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_thread_k != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible "
f"by min_thread_k = {self.quant_config.min_thread_k}.")
if (group_size < input_size
and input_size_per_partition % group_size != 0):
raise ValueError(
f"Weight input_size_per_partition = {input_size_per_partition}"
f" is not divisible by group_size = {group_size}.")
# Detect sharding of scales/zp
# By default, no sharding over "input dim"
scales_and_zp_size = input_size // group_size
scales_and_zp_input_dim = None
if self.quant_config.desc_act:
# Act-order case
assert self.quant_config.group_size != -1
is_k_full = input_size_per_partition == input_size
else:
# No act-order case
# K is always full due to full alignment with
# group-size and shard of scales/zp
is_k_full = True
# If this is a row-parallel case, then shard scales/zp
if (input_size != input_size_per_partition
and self.quant_config.group_size != -1):
scales_and_zp_size = input_size_per_partition // group_size
scales_and_zp_input_dim = 0
# Init buffers
# Quantized weights
qweight = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.pack_factor,
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
**extra_weight_attrs,
"input_dim": 0,
"output_dim": 1,
"packed_dim": 0,
"pack_factor": self.quant_config.pack_factor,
},
)
# Activation order
g_idx = Parameter(
torch.empty(
input_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs(
g_idx,
{
**extra_weight_attrs, "input_dim": 0,
"ignore_warning": True
},
)
g_idx_sort_indices = torch.empty(
g_idx.shape,
dtype=torch.int32,
)
# Scales
scales = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition,
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
},
)
# Quantized zero-points
qzeros = Parameter(
torch.empty(
scales_and_zp_size,
output_size_per_partition // self.quant_config.pack_factor,
dtype=torch.int32,
device="meta",
),
requires_grad=False,
)
set_weight_attrs(
qzeros,
{
**extra_weight_attrs,
"input_dim": scales_and_zp_input_dim,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
},
)
# Allocate marlin workspace
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_thread_n) * self.quant_config.max_parallel
workspace = torch.zeros(max_workspace_size,
dtype=torch.int,
requires_grad=False)
layer.register_parameter("qweight", qweight)
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.g_idx_sort_indices = g_idx_sort_indices
layer.workspace = workspace
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = is_k_full
layer.marlin_state = GPTQMarlinState.REPACK
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
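replace_tensor (defined inside apply above) leans on resize_() plus copy_() so the tensor object vLLM registered keeps its identity while its storage is reshaped and overwritten; assigning a fresh tensor would leave the old one alive. A standalone sketch of that pattern (the toy sizes are assumptions):

import torch

param = torch.zeros(4, 4)                                   # stands in for layer.qweight
new_t = torch.arange(8, dtype=torch.float32).reshape(2, 4)

param.resize_(new_t.shape)   # reuse the existing tensor object's storage
param.copy_(new_t)           # overwrite contents in place
del new_t                    # temporary can be freed; `param` stays valid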
from logging import getLogger
import torch
import torch.nn as nn
logger = getLogger(__name__)
def quantize(x, scale, zero, maxq):
if maxq < 0:
return (x > scale / 2).float() * scale + (x < zero / 2).float() * zero
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
return scale * (q - zero)
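As a quick sanity check of quantize(): with 4-bit asymmetric parameters (maxq = 15), each value is rounded to one of 16 grid points and mapped back, and anything outside the grid saturates. A worked example with assumed inputs:

import torch

x = torch.tensor([0.30, -0.70, 1.00])
scale = torch.tensor([0.10])
zero = torch.tensor([8.0])        # representable range is [-0.8, 0.7]
maxq = torch.tensor(15)           # 2**4 - 1

q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)   # tensor([11., 1., 15.])
x_hat = scale * (q - zero)                                # tensor([0.30, -0.70, 0.70])
# 1.00 lies above the grid maximum (15 - 8) * 0.1 = 0.7, so it clips to 0.70.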
class Quantizer(nn.Module):
def __init__(self, shape=1):
super(Quantizer, self).__init__()
self.register_buffer("maxq", torch.tensor(0))
self.register_buffer("scale", torch.zeros(shape))
self.register_buffer("zero", torch.zeros(shape))
def configure(
self,
bits,
perchannel=False,
sym=True,
mse=False,
norm=2.4,
grid=100,
maxshrink=0.8,
trits=False,
):
self.maxq = torch.tensor(2**bits - 1)
self.perchannel = perchannel
self.sym = sym
self.mse = mse
self.norm = norm
self.grid = grid
self.maxshrink = maxshrink
if trits:
self.maxq = torch.tensor(-1)
def find_params(self, x, weight=False):
dev = x.device
self.maxq = self.maxq.to(dev)
shape = x.shape
if self.perchannel:
if weight:
x = x.flatten(1)
else:
if len(shape) == 4:
x = x.permute([1, 0, 2, 3])
x = x.flatten(1)
if len(shape) == 3:
x = x.reshape((-1, shape[-1])).t()
if len(shape) == 2:
x = x.t()
else:
x = x.flatten().unsqueeze(0)
tmp = torch.zeros(x.shape[0], device=dev)
xmin = torch.minimum(x.min(1)[0], tmp)
xmax = torch.maximum(x.max(1)[0], tmp)
if self.sym:
xmax = torch.maximum(torch.abs(xmin), xmax)
tmp = xmin < 0
if torch.any(tmp):
xmin[tmp] = -xmax[tmp]
tmp = (xmin == 0) & (xmax == 0)
xmin[tmp] = -1
xmax[tmp] = +1
if self.maxq < 0:
self.scale = xmax
self.zero = xmin
else:
self.scale = (xmax - xmin) / self.maxq
if self.sym:
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
else:
self.zero = torch.round(-xmin / self.scale)
if self.mse:
best = torch.full([x.shape[0]], float("inf"), device=dev)
for i in range(int(self.maxshrink * self.grid)):
p = 1 - i / self.grid
xmin1 = p * xmin
xmax1 = p * xmax
scale1 = (xmax1 - xmin1) / self.maxq
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
q -= x
q.abs_()
q.pow_(self.norm)
err = torch.sum(q, 1)
tmp = err < best
if torch.any(tmp):
best[tmp] = err[tmp]
self.scale[tmp] = scale1[tmp]
self.zero[tmp] = zero1[tmp]
if not self.perchannel:
if weight:
tmp = shape[0]
else:
tmp = shape[1] if len(shape) != 3 else shape[2]
self.scale = self.scale.repeat(tmp)
self.zero = self.zero.repeat(tmp)
if weight:
shape = [-1] + [1] * (len(shape) - 1)
self.scale = self.scale.reshape(shape)
self.zero = self.zero.reshape(shape)
return
if len(shape) == 4:
self.scale = self.scale.reshape((1, -1, 1, 1))
self.zero = self.zero.reshape((1, -1, 1, 1))
if len(shape) == 3:
self.scale = self.scale.reshape((1, 1, -1))
self.zero = self.zero.reshape((1, 1, -1))
if len(shape) == 2:
self.scale = self.scale.unsqueeze(0)
self.zero = self.zero.unsqueeze(0)
def quantize(self, x):
if self.ready():
return quantize(x, self.scale, self.zero, self.maxq)
return x
def enabled(self):
return self.maxq > 0
def ready(self):
return torch.all(self.scale != 0)
__all__ = ["Quantizer"]
import torch
import enum
from enum import Enum
from typing import Any, Dict, List, Optional
from torch.nn.parameter import Parameter
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
reshaped_x = x.reshape(-1, x.shape[-1])
size_m = reshaped_x.shape[0]
part_size_n = layer.output_size_per_partition
part_size_k = layer.input_size_per_partition
full_size_k = layer.input_size
out_shape = x.shape[:-1] + (part_size_n, )
if layer.marlin_state == GPTQMarlinState.REPACK:
layer.marlin_state = GPTQMarlinState.READY
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(name, new_t):
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
cur_device = layer.qweight.device
# Process act_order
if self.quant_config.desc_act:
# Get sorting based on g_idx
g_idx_sort_indices = torch.argsort(layer.g_idx).to(torch.int)
sorted_g_idx = layer.g_idx[g_idx_sort_indices]
replace_tensor("g_idx", sorted_g_idx)
replace_tensor("g_idx_sort_indices", g_idx_sort_indices)
else:
# Reset g_idx related tensors
layer.g_idx = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
layer.g_idx_sort_indices = Parameter(
torch.empty(0, dtype=torch.int, device=cur_device),
requires_grad=False,
)
# Repack weights
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
layer.g_idx_sort_indices,
part_size_k,
part_size_n,
self.quant_config.weight_bits,
)
replace_tensor("qweight", marlin_qweight)
# Permute scales
scales_size_k = part_size_k
scales_size_n = part_size_n
if self.quant_config.desc_act:
scales_size_k = full_size_k
marlin_scales = marlin_permute_scales(
layer.scales,
scales_size_k,
scales_size_n,
self.quant_config.group_size,
self.quant_config.weight_bits,
)
replace_tensor("scales", marlin_scales)
output = ops.gptq_marlin_gemm(
reshaped_x,
layer.qweight,
layer.scales,
layer.g_idx,
layer.g_idx_sort_indices,
layer.workspace,
self.quant_config.weight_bits,
size_m,
part_size_n,
part_size_k,
layer.is_k_full,
)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
......@@ -220,7 +220,7 @@ def compute_max_diff(output, output_ref):
class MarlinWorkspace:
def __init__(self, out_features, min_thread_n, max_parallel):
def __init__(self, out_features, min_thread_n, max_parallel, device):
assert (out_features % min_thread_n == 0), (
"out_features = {} is undivisible by min_thread_n = {}".format(
out_features, min_thread_n))
......@@ -229,4 +229,4 @@ class MarlinWorkspace:
self.scratch = torch.zeros(max_workspace_size,
dtype=torch.int,
device="cuda")
device=device)
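The scratch length matches the formula used when creating weights earlier: one slot per min_thread_n-wide tile of the output dimension, times the maximum inter-block parallelism. With assumed values:

out_features = 11008    # assumed
min_thread_n = 64       # Marlin minimum tile width along N (assumed)
max_parallel = 16       # assumed

max_workspace_size = (out_features // min_thread_n) * max_parallel   # 2752 ints
# workspace = torch.zeros(max_workspace_size, dtype=torch.int, device=device)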
......@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:34:58
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
......@@ -13,9 +13,15 @@ Linear::Linear(LinearConfig config) {
config_ = config;
proj_ = config_.proj;
input_fp32_.resize(config_.input_size);
proj_input_.resize(config_.input_size * 4);
proj_output_.resize(config_.output_size);
std::vector<std::pair<void**, uint64_t>> mem_requests;
mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.input_size});
mem_requests.push_back({(void**)&proj_input_, config_.group_max_len * config_.input_size * ggml_type_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.proj_type).vec_dot_type)});
mem_requests.push_back({(void**)&proj_output_, sizeof(float) * config_.group_max_len * config_.output_size});
shared_mem_buffer.alloc(this, mem_requests);
}
Linear::~Linear() {
shared_mem_buffer.dealloc(this);
}
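shared_mem_buffer.alloc receives a list of (pointer-slot, byte-count) requests and parcels out ranges of one shared scratch arena, so Linear, MLP, and MOE instances reuse the same memory instead of each owning per-object std::vectors; dealloc unregisters the owner. The SharedMemBuffer implementation itself is not part of this diff, so the following Python sketch is an assumption about its semantics:

class SharedMemBuffer:
    # Sketch: lay each owner's requests out back to back from offset 0 and
    # remember only the largest total, since the arena is reused serially.
    def __init__(self):
        self.capacity = 0                        # bytes in the backing buffer

    def alloc(self, owner, requests):            # requests: [(name, nbytes), ...]
        offsets, offset = {}, 0
        for name, nbytes in requests:
            offsets[name] = offset
            offset += nbytes
        self.capacity = max(self.capacity, offset)
        return offsets                           # caller binds these to pointers

shared_mem_buffer = SharedMemBuffer()
offs = shared_mem_buffer.alloc("linear0", [("input_fp32", 4 * 1024 * 4096),
                                           ("proj_output", 4 * 1024 * 4096)])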
void Linear::warm_up(Backend* backend) {
......@@ -26,22 +32,42 @@ void Linear::warm_up(Backend* backend) {
input_fp32[i] = 0;
}
from_float(input_fp32.data(), input.data(), config_.input_size, config_.hidden_type);
forward(input.data(), output.data(), backend);
forward_many(1, input.data(), output.data(), backend);
}
void Linear::forward(const void* input, void* output, Backend* backend) {
void Linear::forward_many(int qlen, const void* input, void* output, Backend* backend) {
const void* proj_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.proj_type).vec_dot_type) {
proj_input_ptr = input;
} else {
to_float(input, input_fp32_.data(), config_.input_size, config_.hidden_type);
from_float(input_fp32_.data(), proj_input_.data(), config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);
proj_input_ptr = proj_input_.data();
to_float(input, input_fp32_, qlen * config_.input_size, config_.hidden_type);
from_float(input_fp32_, proj_input_, qlen * config_.input_size, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type);
proj_input_ptr = proj_input_;
}
int nth = config_.output_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id % nth;
llamafile_sgemm(config_.output_size, 1, config_.input_size / ggml_blck_size(config_.proj_type), proj_, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_.data(), config_.output_size, ith, nth, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
int ith = task_id;
void* proj_ptr = (uint8_t*)proj_ + ith * config_.stride * config_.input_size * ggml_type_size(config_.proj_type) / ggml_blck_size(config_.proj_type);
float* proj_output_ptr = proj_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.input_size / ggml_blck_size(config_.proj_type), proj_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_input_ptr, config_.input_size / ggml_blck_size(config_.proj_type), proj_output_ptr, config_.output_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.proj_type, ggml_internal_get_type_traits(config_.proj_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
for (int i = 0; i < qlen; i++) {
float* output_fp32_ptr = proj_output_ + i * config_.output_size + ith * config_.stride;
void* output_ptr = (uint8_t*)output + i * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
});
from_float(proj_output_.data(), output, config_.output_size, config_.hidden_type);
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(proj_output_, output, qlen * config_.output_size, config_.hidden_type);
}
}
void Linear::forward(int qlen, const void* input, void* output, Backend* backend) {
if (qlen <= 0) {
return;
}
int forward_len = std::min(qlen, config_.group_max_len);
forward_many(forward_len, input, output, backend);
forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.input_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.output_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
\ No newline at end of file
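Linear::forward (just above) processes at most group_max_len rows per call to forward_many, then tail-recurses on the remainder with both pointers advanced by the bytes consumed. The same control flow as an iterative Python sketch (the row-byte sizes are assumptions standing in for the ggml type arithmetic):

def forward(qlen, in_off, out_off, group_max_len, in_row_bytes, out_row_bytes):
    # Chunk a batch of qlen rows into group_max_len-sized pieces.
    while qlen > 0:
        n = min(qlen, group_max_len)
        # forward_many(n, input + in_off, output + out_off, backend)
        in_off += n * in_row_bytes
        out_off += n * out_row_bytes
        qlen -= n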
......@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:00
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
......@@ -22,34 +22,38 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct LinearConfig {
int input_size;
int output_size;
int stride;
int group_max_len;
void* proj;
ggml_type proj_type;
ggml_type hidden_type;
LinearConfig() {}
LinearConfig(int input_size, int output_size, int stride, void* proj, ggml_type proj_type, ggml_type hidden_type)
: input_size(input_size), output_size(output_size), stride(stride), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
LinearConfig(int input_size, int output_size, int stride, int group_max_len, void* proj, ggml_type proj_type, ggml_type hidden_type)
: input_size(input_size), output_size(output_size), stride(stride), group_max_len(group_max_len), proj(proj), proj_type(proj_type), hidden_type(hidden_type) {}
};
class Linear {
public:
Linear(LinearConfig);
~Linear();
void warm_up(Backend* backend);
void forward(const void* input, void* output, Backend* backend);
void forward_many(int qlen, const void* input, void* output, Backend* backend);
void forward(int qlen, const void* input, void* output, Backend* backend);
private:
LinearConfig config_;
void* proj_; // [output_size * input_size ( /32 if quantized)]
std::vector<float> input_fp32_; // [input_size]
std::vector<uint8_t> proj_input_; // [input_size * 4]
std::vector<float> proj_output_; // [output_size]
float* input_fp32_; // [group_max_len * input_size]
uint8_t* proj_input_; // [group_max_len * input_size * ggml_type_size(ggml_internal_get_type_traits(proj_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(proj_type).vec_dot_type)]
float* proj_output_; // [group_max_len * output_size]
};
#endif
\ No newline at end of file
......@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-16 10:43:18
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
......@@ -15,14 +15,20 @@ MLP::MLP(MLPConfig config) {
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
input_fp32_.resize(config_.hidden_size);
gate_input_.resize(config_.hidden_size * 4);
up_input_.resize(config_.hidden_size * 4);
gate_output_.resize(config_.intermediate_size);
up_output_.resize(config_.intermediate_size);
intermediate_fp32_.resize(config_.intermediate_size);
down_input_.resize(config_.intermediate_size * 4);
down_output_.resize(config_.hidden_size);
std::vector<std::pair<void**, uint64_t>> mem_requests;
mem_requests.push_back({(void**)&input_fp32_, sizeof(float) * config_.group_max_len * config_.hidden_size});
mem_requests.push_back({(void**)&gate_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
mem_requests.push_back({(void**)&up_input_, config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
mem_requests.push_back({(void**)&gate_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
mem_requests.push_back({(void**)&up_output_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
mem_requests.push_back({(void**)&intermediate_fp32_, sizeof(float) * config_.group_max_len * config_.intermediate_size});
mem_requests.push_back({(void**)&down_input_, config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
mem_requests.push_back({(void**)&down_output_, sizeof(float) * config_.group_max_len * config_.hidden_size});
shared_mem_buffer.alloc(this, mem_requests);
}
MLP::~MLP() {
shared_mem_buffer.dealloc(this);
}
void MLP::warm_up(Backend* backend) {
......@@ -33,33 +39,33 @@ void MLP::warm_up(Backend* backend) {
input_fp32[i] = 0;
}
from_float(input_fp32.data(), input.data(), config_.hidden_size, config_.hidden_type);
forward(input.data(), output.data(), backend);
forward_many(1, input.data(), output.data(), backend);
}
static float act_fn(float x) {
return x / (1.0f + expf(-x));
}
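act_fn above is SiLU (a.k.a. swish): x * sigmoid(x), written here as x / (1 + exp(-x)). It gates the gate projection before the elementwise product with the up projection. A few reference values:

import math

def act_fn(x):
    return x / (1.0 + math.exp(-x))   # x * sigmoid(x), i.e. SiLU

act_fn(0.0)    # 0.0
act_fn(1.0)    # ~0.731
act_fn(-5.0)   # ~-0.033: small negatives are smoothly squashed toward zero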
void MLP::forward(const void* input, void* output, Backend* backend) {
void MLP::forward_many(int qlen, const void* input, void* output, Backend* backend) {
const void* gate_input_ptr;
const void* up_input_ptr;
if (config_.hidden_type == ggml_internal_get_type_traits(config_.gate_type).vec_dot_type && config_.hidden_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
gate_input_ptr = up_input_ptr = input;
} else {
to_float(input, input_fp32_.data(), config_.hidden_size, config_.hidden_type);
to_float(input, input_fp32_, qlen * config_.hidden_size, config_.hidden_type);
if (ggml_internal_get_type_traits(config_.gate_type).vec_dot_type == ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = up_input_ptr = gate_input_.data();
from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = up_input_ptr = gate_input_;
} else {
if (config_.hidden_type != ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) {
from_float(input_fp32_.data(), gate_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = gate_input_.data();
from_float(input_fp32_, gate_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
gate_input_ptr = gate_input_;
} else {
gate_input_ptr = input;
}
if (config_.hidden_type != ggml_internal_get_type_traits(config_.up_type).vec_dot_type) {
from_float(input_fp32_.data(), up_input_.data(), config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
up_input_ptr = up_input_.data();
from_float(input_fp32_, up_input_, qlen * config_.hidden_size, ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
up_input_ptr = up_input_;
} else {
up_input_ptr = input;
}
......@@ -69,35 +75,49 @@ void MLP::forward(const void* input, void* output, Backend* backend) {
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id;
void* gate_proj_ptr = (uint8_t*)gate_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.gate_type) / ggml_blck_size(config_.gate_type);
float* gate_output_ptr = gate_output_.data() + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
float* gate_output_ptr = gate_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_proj_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_input_ptr, config_.hidden_size / ggml_blck_size(config_.gate_type), gate_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.gate_type, ggml_internal_get_type_traits(config_.gate_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
void* up_proj_ptr = (uint8_t*)up_proj_ + ith * config_.stride * config_.hidden_size * ggml_type_size(config_.up_type) / ggml_blck_size(config_.up_type);
float* up_output_ptr = up_output_.data() + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = ith * config_.stride; i < (ith + 1) * config_.stride; i++) {
intermediate_fp32_[i] = act_fn(gate_output_[i]) * up_output_[i];
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = intermediate_fp32_.data() + ith * config_.stride;
void* down_input_ptr = (uint8_t*)down_input_.data() + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
float* up_output_ptr = up_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.hidden_size / ggml_blck_size(config_.up_type), up_proj_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_input_ptr, config_.hidden_size / ggml_blck_size(config_.up_type), up_output_ptr, config_.intermediate_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.up_type, ggml_internal_get_type_traits(config_.up_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
for (int i = 0; i < qlen; i++) {
for (int j = ith * config_.stride; j < (ith + 1) * config_.stride; j++) {
intermediate_fp32_[i * config_.intermediate_size + j] = act_fn(gate_output_[i * config_.intermediate_size + j]) * up_output_[i * config_.intermediate_size + j];
}
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) == 0) {
float* intermediate_fp32_ptr = intermediate_fp32_ + i * config_.intermediate_size + ith * config_.stride;
void* down_input_ptr = (uint8_t*)down_input_ + i * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) + ith * config_.stride * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_ptr, down_input_ptr, config_.stride, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
}
});
if (config_.stride % ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) != 0) {
from_float(intermediate_fp32_.data(), down_input_.data(), config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
from_float(intermediate_fp32_, down_input_, qlen * config_.intermediate_size, ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
}
nth = config_.hidden_size / config_.stride;
backend->do_work_stealing_job(nth, [&](int task_id) {
int ith = task_id;
void* down_proj_ptr = (uint8_t*)down_proj_ + ith * config_.stride * config_.intermediate_size * ggml_type_size(config_.down_type) / ggml_blck_size(config_.down_type);
float* down_output_ptr = down_output_.data() + ith * config_.stride;
llamafile_sgemm(config_.stride, 1, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_.data(), config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.stride, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
float* down_output_ptr = down_output_ + ith * config_.stride;
llamafile_sgemm(config_.stride, qlen, config_.intermediate_size / ggml_blck_size(config_.down_type), down_proj_ptr, config_.intermediate_size / ggml_blck_size(config_.down_type), down_input_, config_.intermediate_size / ggml_blck_size(config_.down_type), down_output_ptr, config_.hidden_size, 0, 1, GGML_TASK_TYPE_COMPUTE, config_.down_type, ggml_internal_get_type_traits(config_.down_type).vec_dot_type, GGML_TYPE_F32, GGML_PREC_DEFAULT);
if (config_.stride % ggml_blck_size(config_.hidden_type) == 0) {
void* output_ptr = (uint8_t*)output + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(down_output_ptr, output_ptr, config_.stride, config_.hidden_type);
for (int i = 0; i < qlen; i++) {
float* output_fp32_ptr = down_output_ + i * config_.hidden_size + ith * config_.stride;
void* output_ptr = (uint8_t*)output + i * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type) + ith * config_.stride * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type);
from_float(output_fp32_ptr, output_ptr, config_.stride, config_.hidden_type);
}
}
});
if (config_.stride % ggml_blck_size(config_.hidden_type) != 0) {
from_float(down_output_.data(), output, config_.hidden_size, config_.hidden_type);
from_float(down_output_, output, qlen * config_.hidden_size, config_.hidden_type);
}
}
void MLP::forward(int qlen, const void* input, void* output, Backend* backend) {
if (qlen <= 0) {
return;
}
int forward_len = std::min(qlen, config_.group_max_len);
forward_many(forward_len, input, output, backend);
forward(qlen - forward_len, (uint8_t*)input + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), (uint8_t*)output + forward_len * config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type), backend);
}
\ No newline at end of file
......@@ -3,7 +3,7 @@
* @Author : chenht2022
* @Date : 2024-07-12 10:07:58
* @Version : 1.0.0
* @LastEditors : chenht2022
* @LastEditTime : 2024-07-25 10:35:06
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/
......@@ -22,11 +22,13 @@
#include "llama.cpp/ggml-quants.h"
#include "llama.cpp/ggml.h"
#include "llamafile/sgemm.h"
#include "shared_mem_buffer.h"
struct MLPConfig {
int hidden_size;
int intermediate_size;
int stride;
int group_max_len;
void* gate_proj;
void* up_proj;
void* down_proj;
......@@ -37,15 +39,17 @@ struct MLPConfig {
MLPConfig() {}
MLPConfig(int hidden_size, int intermediate_size, int stride, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
: hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
MLPConfig(int hidden_size, int intermediate_size, int stride, int group_max_len, void* gate_proj, void* up_proj, void* down_proj, ggml_type gate_type, ggml_type up_type, ggml_type down_type, ggml_type hidden_type)
: hidden_size(hidden_size), intermediate_size(intermediate_size), stride(stride), group_max_len(group_max_len), gate_proj(gate_proj), up_proj(up_proj), down_proj(down_proj), gate_type(gate_type), up_type(up_type), down_type(down_type), hidden_type(hidden_type) {}
};
class MLP {
public:
MLP(MLPConfig);
~MLP();
void warm_up(Backend* backend);
void forward(const void* input, void* output, Backend* backend);
void forward_many(int qlen, const void* input, void* output, Backend* backend);
void forward(int qlen, const void* input, void* output, Backend* backend);
private:
MLPConfig config_;
......@@ -53,14 +57,14 @@ class MLP {
void* up_proj_; // [intermediate_size * hidden_size ( /32 if quantized)]
void* down_proj_; // [hidden_size * intermediate_size ( /32 if quantized)]
std::vector<float> input_fp32_; // [hidden_size]
std::vector<uint8_t> gate_input_; // [hidden_size * 4]
std::vector<uint8_t> up_input_; // [hidden_size * 4]
std::vector<float> gate_output_; // [intermediate_size]
std::vector<float> up_output_; // [intermediate_size]
std::vector<float> intermediate_fp32_; // [intermediate_size]
std::vector<uint8_t> down_input_; // [intermediate_size * 4]
std::vector<float> down_output_; // [hidden_size]
float* input_fp32_; // [group_max_len * hidden_size]
uint8_t* gate_input_; // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(gate_type).vec_dot_type)]
uint8_t* up_input_; // [group_max_len * hidden_size * ggml_type_size(ggml_internal_get_type_traits(up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(up_type).vec_dot_type)]
float* gate_output_; // [group_max_len * intermediate_size]
float* up_output_; // [group_max_len * intermediate_size]
float* intermediate_fp32_; // [group_max_len * intermediate_size]
uint8_t* down_input_; // [group_max_len * intermediate_size * ggml_type_size(ggml_internal_get_type_traits(down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(down_type).vec_dot_type)]
float* down_output_; // [group_max_len * hidden_size]
};
#endif
\ No newline at end of file
/**
 * @Description :
 * @Author : chenht2022
 * @Date : 2024-07-22 02:03:22
 * @Version : 1.0.0
 * @LastEditors : chenht2022
 * @LastEditTime : 2024-07-25 10:35:07
 * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
 **/
#include "moe.h"
#include <iostream>
#include <cstdint>
uint8_t* MOE::buffer_ = nullptr;
MOE::MOE(MOEConfig config) {
config_ = config;
gate_proj_ = config_.gate_proj;
up_proj_ = config_.up_proj;
down_proj_ = config_.down_proj;
if (MOE::buffer_ == nullptr) {
uint64_t buffer_size = 0;
buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
buffer_size += config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
buffer_size += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
buffer_size += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
buffer_size += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
buffer_size += sizeof(float) * config_.group_max_len * config_.hidden_size;
buffer_ = (uint8_t*)malloc(buffer_size);
}
uint64_t offset = 0;
s_input_fp32_ = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.hidden_size;
s_gate_input_ = (uint8_t*)(buffer_ + offset);
offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
s_up_input_ = (uint8_t*)(buffer_ + offset);
offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
std::vector<std::pair<void**, uint64_t>> s_mem_requests;
s_mem_requests.push_back({(void**)&s_input_fp32_, sizeof(float) * config_.hidden_size});
s_mem_requests.push_back({(void**)&s_gate_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
s_mem_requests.push_back({(void**)&s_up_input_, config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
s_gate_output_.resize(config_.routed_expert_num);
s_up_output_.resize(config_.routed_expert_num);
s_intermediate_fp32_.resize(config_.routed_expert_num);
s_down_input_.resize(config_.routed_expert_num);
s_down_output_.resize(config_.routed_expert_num);
for (int i = 0; i < config_.routed_expert_num; i++) {
s_gate_output_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.intermediate_size;
s_up_output_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.intermediate_size;
s_intermediate_fp32_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.intermediate_size;
s_down_input_[i] = (uint8_t*)(buffer_ + offset);
offset += config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
s_down_output_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.hidden_size;
s_mem_requests.push_back({(void**)&s_gate_output_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_up_output_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_intermediate_fp32_[i], sizeof(float) * config_.intermediate_size});
s_mem_requests.push_back({(void**)&s_down_input_[i], config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
s_mem_requests.push_back({(void**)&s_down_output_[i], sizeof(float) * config_.hidden_size});
}
s_output_fp32_ = (float*)(buffer_ + offset);
s_mem_requests.push_back({(void**)&s_output_fp32_, sizeof(float) * config_.hidden_size});
shared_mem_buffer.alloc(this, s_mem_requests);
offset = 0;
std::vector<std::pair<void**, uint64_t>> m_mem_requests;
m_input_fp32_.resize(config_.group_max_len);
m_gate_input_.resize(config_.group_max_len);
m_up_input_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
m_input_fp32_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.hidden_size;
m_gate_input_[i] = (uint8_t*)(buffer_ + offset);
offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
m_up_input_[i] = (uint8_t*)(buffer_ + offset);
offset += config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
m_mem_requests.push_back({(void**)&m_input_fp32_[i], sizeof(float) * config_.hidden_size});
m_mem_requests.push_back({(void**)&m_gate_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_up_input_[i], config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
}
m_local_gate_input_ = (uint8_t*)(buffer_ + offset);
offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type);
m_local_up_input_ = (uint8_t*)(buffer_ + offset);
offset += config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type);
m_local_gate_output_ = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
m_local_up_output_ = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
m_local_intermediate_fp32_ = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size;
m_local_down_input_ = (uint8_t*)(buffer_ + offset);
offset += config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type);
m_local_down_output_ = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size;
m_mem_requests.push_back({(void**)&m_local_gate_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.gate_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_up_input_, config_.routed_expert_num * config_.group_max_len * config_.hidden_size * ggml_type_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.up_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_gate_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_up_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_intermediate_fp32_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.intermediate_size});
m_mem_requests.push_back({(void**)&m_local_down_input_, config_.routed_expert_num * config_.group_max_len * config_.intermediate_size * ggml_type_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type) / ggml_blck_size(ggml_internal_get_type_traits(config_.down_type).vec_dot_type)});
m_mem_requests.push_back({(void**)&m_local_down_output_, sizeof(float) * config_.routed_expert_num * config_.group_max_len * config_.hidden_size});
m_output_fp32_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
m_output_fp32_[i] = (float*)(buffer_ + offset);
offset += sizeof(float) * config_.hidden_size;
m_mem_requests.push_back({(void**)&m_output_fp32_[i], sizeof(float) * config_.hidden_size});
}
shared_mem_buffer.alloc(this, m_mem_requests);
m_local_pos_.resize(config_.group_max_len);
for (int i = 0; i < config_.group_max_len; i++) {
......@@ -107,6 +72,10 @@ MOE::MOE(MOEConfig config) {
m_local_down_output_ptr_.resize(config_.expert_num);
}
MOE::~MOE() {
shared_mem_buffer.dealloc(this);
}
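The request list in the MOE constructor adds up to the same total the old hand-maintained buffer_ computed with a running offset; the shared-buffer version simply registers each piece and lets the arena do the layout. A rough Python tally of the dominant terms, with assumed model dimensions and roughly 1 byte per element for a q8_0-like vec_dot type:

hidden, inter = 7168, 2048        # assumed hidden / intermediate sizes
experts, max_len = 8, 1024        # assumed routed_expert_num / group_max_len
f32, q8 = 4, 1                    # bytes per element (q8 is a rough estimate)

total  = 2 * f32 * max_len * hidden              # input_fp32 + output_fp32
total += 2 * q8 * max_len * hidden               # gate_input + up_input
total += 2 * q8 * experts * max_len * hidden     # m_local_{gate,up}_input
total += 3 * f32 * experts * max_len * inter     # gate/up/intermediate fp32
total += q8 * experts * max_len * inter          # m_local_down_input
total += f32 * experts * max_len * hidden        # m_local_down_output
print(f"{total / 2**30:.2f} GiB of shared scratch")   # ~0.60 GiB at these sizes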
void MOE::warm_up(Backend* backend) {
std::vector<float> input_fp32(config_.hidden_size);
std::vector<uint8_t> input(config_.hidden_size * ggml_type_size(config_.hidden_type) / ggml_blck_size(config_.hidden_type));
......