Unverified Commit d5b5bc75 authored by Nicolas Patry's avatar Nicolas Patry Committed by GitHub
Browse files

feat(server): Add exllama GPTQ CUDA kernel support #553 (#666)

Just trying to get the integration tests to pass.


# What does this PR do?

<!--
Congratulations! You've made it this far! You're not quite done yet
though.

Once merged, your PR is going to appear in the release notes with the
title you set, so make sure it's a great title that fully reflects the
extent of your awesome contribution.

Then, please replace this with a description of the change and which
issue is fixed (if applicable). Please also include relevant motivation
and context. List any dependencies (if any) that are required for this
change.

Once you're done, someone will review your PR shortly (see the section
"Who can review?" below to tag some potential reviewers). They may
suggest changes to make the code even better. If no one reviewed your PR
after a week has passed, don't hesitate to post a new comment
@-mentioning the same persons---sometimes notifications get lost.
-->

<!-- Remove if not applicable -->

Fixes # (issue)


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the
other checks if that's the case).
- [ ] Did you read the [contributor
guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the
[forum](https://discuss.huggingface.co/)? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
Here are the
[documentation
guidelines](https://github.com/huggingface/transformers/tree/main/docs),
and
[here are tips on formatting
docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation

).
- [ ] Did you write any new necessary tests?


## Who can review?

Anyone in the community is free to review the PR once the tests have
passed. Feel free to tag
members/contributors who may be interested in your PR.

<!-- Your PR will be replied to more quickly if you can figure out the
right person to tag with @


@OlivierDehaene OR @Narsil

 -->

---------
Co-authored-by: default avatarFelix Marty <9808326+fxmarty@users.noreply.github.com>
parent bf94df3c
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <ATen/cuda/CUDAContext.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#include "util.cuh"
#include "tuning.h"
#include "cuda_buffers.cuh"
#include "cuda_func/q4_matrix.cuh"
#include "cuda_func/q4_matmul.cuh"
#include "cuda_func/column_remap.cuh"
// Check CUDA return code. We don't want to include Torch headers in the .cu files because parsing them adds almost a
// minute to the compile time on a 12900K. Also passing exceptions back to Python is super tricky, so in place of
// exceptions, CUDA functions return with a cudaError_t which we can parse and dump to the console.
void check_cuda(cudaError_t ret)
{
switch (ret)
{
case cudaSuccess:
break;
case cudaUnspecified:
printf(" **** Unspecified error\n");
TORCH_CHECK(false, "CUDA error");
break;
default:
printf(" **** CUDA error\n"); \
printf(" **** %s\n", cudaGetErrorString(ret)); \
TORCH_CHECK(false, "CUDA error"); \
break;
}
}
// Some decluttering macros
#define STRINGIFY_(__x) #__x
#define STRINGIFY(__x) STRINGIFY_(__x)
#define TORCH_CHECK_DTYPE(__x, __dtype) TORCH_CHECK((__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_DTYPE_OPT(__x, __dtype) TORCH_CHECK((__x).device().is_meta() || (__x).dtype() == torch::__dtype, #__x " is incorrect datatype, must be " #__dtype)
#define TORCH_CHECK_SHAPES(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPES_OPT(__x, __dim_x, __y, __dim_y, __scale_y) TORCH_CHECK((__x).device().is_meta() || (__x).size(__dim_x) == (__y).size(__dim_y) * __scale_y, #__x " and " #__y " have incompatible shapes")
#define TORCH_CHECK_SHAPE_MOD(__x, __dim_x, __mod) TORCH_CHECK((__x).size(__dim_x) % __mod == 0, #__x ".shape[" STRINGIFY(__dim_x) "] must be a multiple of " STRINGIFY(__mod))
#define TORCH_CHECK_DEVICE_INDEX(__index) \
do { \
TORCH_CHECK(__index >= 0, "no device index"); \
TORCH_CHECK(__index < CUDA_MAX_DEVICES, "invalid device index"); \
} while(0)
#define TORCH_CHECK_QUANT(__w, __w_scales, __w_zeros, __seq_g_idx, __x_map) \
do { \
TORCH_CHECK_DTYPE(__w, kInt); \
TORCH_CHECK_DTYPE(__w_scales, kHalf); \
TORCH_CHECK_DTYPE(__w_zeros, kInt); \
TORCH_CHECK_DTYPE_OPT(__seq_g_idx, kShort); \
TORCH_CHECK_DTYPE_OPT(__x_map, kInt); \
TORCH_CHECK_SHAPES_OPT(__seq_g_idx, 0, __w, 0, 2 * 8); \
TORCH_CHECK_SHAPES_OPT(__x_map, 0, __w, 0, 8); \
} while(0)
int get_groupsize(torch::Tensor w, torch::Tensor w_zeros)
{
int groupsize = w.size(0) * 8 / w_zeros.size(0);
TORCH_CHECK(groupsize * w_zeros.size(0) == w.size(0) * 8, "w.shape[-2] must be a multiple of zeros.shape[-2]")
return groupsize;
}
// Tuning parameters
ExLlamaTuning tuningParams;
void set_tuning_params
(
int matmul_recons_thd,
bool matmul_fused_remap,
bool matmul_no_half2
)
{
tuningParams.matmul_recons_thd = matmul_recons_thd;
tuningParams.matmul_fused_remap = matmul_fused_remap;
tuningParams.matmul_no_half2 = matmul_no_half2;
}
// Release all unmanaged objects allocated by the extension
void cleanup()
{
cleanup_buffers_cuda();
g_q4_free_matrices();
}
// Prepare buffers for forward pass
void prepare_buffers
(
torch::Device device,
torch::Tensor temp_state,
torch::Tensor temp_dq
)
{
int device_index = device.index();
TORCH_CHECK_DEVICE_INDEX(device_index);
const at::cuda::OptionalCUDAGuard device_guard(device);
prepare_buffers_cuda
(
device_index,
(half*) temp_state.data_ptr(),
(half*) temp_dq.data_ptr()
);
}
// Create Q4Matrix, return handle
uintptr_t make_q4
(
torch::Tensor qweight,
torch::Tensor qzeros,
torch::Tensor scales,
torch::Tensor g_idx,
int device
)
{
TORCH_CHECK_DTYPE(qweight, kInt);
TORCH_CHECK_DTYPE(qzeros, kInt);
TORCH_CHECK_DTYPE(scales, kHalf);
TORCH_CHECK_DTYPE_OPT(g_idx, kInt);
TORCH_CHECK_SHAPES(qweight, 1, qzeros, 1, 8);
TORCH_CHECK_SHAPES(scales, 1, qweight, 1, 1);
TORCH_CHECK_SHAPES(qzeros, 0, scales, 0, 1);
int width = qweight.size(1);
int height = qweight.size(0) * 8;
int groups = qzeros.size(0);
Q4Matrix* m = new Q4Matrix
(
height,
width,
groups,
(uint32_t*) qweight.data_ptr(),
(uint32_t*) qzeros.data_ptr(),
(half*) scales.data_ptr(),
g_idx.device().is_meta() ? NULL : (uint32_t*) g_idx.data_ptr(),
device
);
g_q4_keep_matrix(m);
return reinterpret_cast<uintptr_t> (m);
}
// Matmul half @ quant -> half
void q4_matmul
(
torch::Tensor x,
uintptr_t w,
torch::Tensor out
)
{
Q4Matrix* wm = reinterpret_cast<Q4Matrix*> (w);
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(out, kHalf);
TORCH_CHECK_SHAPES(x, 0, out, 0, 1);
TORCH_CHECK(wm->height == x.size(-1), "x and w have incompatible shapes")
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
int x_height = x.size(0);
if (tuningParams.matmul_recons_thd == 0 || x_height < tuningParams.matmul_recons_thd)
{
q4_matmul_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr()
);
}
else
{
q4_matmul_recons_cuda
(
&tuningParams,
(half*) x.data_ptr(),
x_height,
wm,
(half*) out.data_ptr(),
at::cuda::getCurrentCUDABlasHandle()
);
}
}
// Remap columns in half tensor
void column_remap
(
torch::Tensor x,
torch::Tensor x_new,
torch::Tensor x_map
)
{
TORCH_CHECK_DTYPE(x, kHalf);
TORCH_CHECK_DTYPE(x_new, kHalf);
TORCH_CHECK_DTYPE(x_map, kInt);
TORCH_CHECK_SHAPES(x_map, 0, x, 1, 1);
int height = x.size(0);
int width = x.size(1);
const at::cuda::OptionalCUDAGuard device_guard(device_of(x));
column_remap_cuda
(
(half*) x.data_ptr(),
(half*) x_new.data_ptr(),
height,
width,
(uint32_t*) x_map.data_ptr()
);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
m.def("set_tuning_params", &set_tuning_params, "set_tuning_params");
m.def("prepare_buffers", &prepare_buffers, "prepare_buffers");
m.def("cleanup", &cleanup, "cleanup");
m.def("make_q4", &make_q4, "make_q4");
m.def("q4_matmul", &q4_matmul, "q4_matmul");
}
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _matrix_cuh
#define _matrix_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
class MatrixView_half
{
public:
const half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half(const half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half2 item_half2half2(int row, int column) const { return __half2half2(data[row * width + column]); }
__device__ __forceinline__ const half* item_ptr(int row, int column) const { return &data[row * width + column]; }
};
class MatrixView_half_rw
{
public:
half* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_half_rw(half* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ half item(int row, int column) const { return data[row * width + column]; }
__device__ __forceinline__ half2 item_half2(int row, int column) const { return ((half2*)data)[(row * width + column) / 2]; }
__device__ __forceinline__ half* item_ptr(int row, int column) { return &data[row * width + column]; }
__device__ __forceinline__ void set(int row, int column, half value) { data[row * width + column] = value; }
__device__ __forceinline__ void set_half2(int row, int column, half2 value) { ((half2*)data)[(row * width + column) / 2] = value; }
};
class MatrixView_q4_row
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_row(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (column & 0x07) * 4;
return (data[row * width / 8 + column / 8] >> shift) & 0x0f;
}
};
class MatrixView_q4_column
{
public:
const uint32_t* data;
const int height;
const int width;
__device__ __forceinline__ MatrixView_q4_column(const uint32_t* data, const int height, const int width)
: data(data), height(height), width(width)
{ }
__device__ __forceinline__ int item(int row, int column) const
{
int shift = (row & 0x07) * 4;
return (data[row / 8 * width + column] >> shift) & 0x0f;
}
__device__ __forceinline__ uint32_t item_uint32_t(int row, int column) { return data[row / 8 * width + column]; }
__device__ __forceinline__ const uint32_t* item_uint32_ptr(int row, int column) { return &data[row / 8 * width + column]; }
};
// TODO: Rewrite all these dot product functions using functors or something, move to q4_matmul.cu
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale
__device__ __forceinline__ half2 dot_product_8
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half2* h_ptr = (const half2*) h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
// half2 v_01 = q4_table[v_zero - 1][(v_read ) & 0xff]; // (constant memory is too slow apparently)
// half2 v_23 = q4_table[v_zero - 1][(v_read >> 8) & 0xff];
// half2 v_45 = q4_table[v_zero - 1][(v_read >> 16) & 0xff];
// half2 v_67 = q4_table[v_zero - 1][(v_read >> 24) ];
half2 tmp = __hmul2(*h_ptr++, v_01);
tmp = __hfma2(*h_ptr++, v_23, tmp);
tmp = __hfma2(*h_ptr++, v_45, tmp);
tmp = __hfma2(*h_ptr++, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count
)
{
const half* h_ptr = h_.item_ptr(h_row, h_column);
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(*h_ptr++, v_0);
tmp = __hfma(*h_ptr++, v_1, tmp);
tmp = __hfma(*h_ptr++, v_2, tmp);
tmp = __hfma(*h_ptr++, v_3, tmp);
tmp = __hfma(*h_ptr++, v_4, tmp);
tmp = __hfma(*h_ptr++, v_5, tmp);
tmp = __hfma(*h_ptr++, v_6, tmp);
tmp = __hfma(*h_ptr++, v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
// Accumulated dot product of 8-element row vectors in h and quantized column vectors in v, constant zero/scale, with x_map
__device__ __forceinline__ half2 dot_product_8_x_map
(
const half2 acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half2 v_scale_2,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half2 result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half2 v_01 = __halves2half2(v_0, v_1);
half2 v_23 = __halves2half2(v_2, v_3);
half2 v_45 = __halves2half2(v_4, v_5);
half2 v_67 = __halves2half2(v_6, v_7);
half h_0 = h_ptr[*x_map_ptr++];
half h_1 = h_ptr[*x_map_ptr++];
half h_2 = h_ptr[*x_map_ptr++];
half h_3 = h_ptr[*x_map_ptr++];
half h_4 = h_ptr[*x_map_ptr++];
half h_5 = h_ptr[*x_map_ptr++];
half h_6 = h_ptr[*x_map_ptr++];
half h_7 = h_ptr[*x_map_ptr++];
half2 h_01 = __halves2half2(h_0, h_1);
half2 h_23 = __halves2half2(h_2, h_3);
half2 h_45 = __halves2half2(h_4, h_5);
half2 h_67 = __halves2half2(h_6, h_7);
half2 tmp = __hmul2(h_01, v_01);
tmp = __hfma2(h_23, v_23, tmp);
tmp = __hfma2(h_45, v_45, tmp);
tmp = __hfma2(h_67, v_67, tmp);
result = __hfma2(v_scale_2, tmp, result);
}
return result;
}
__device__ __forceinline__ half dot_product_8_x_map_h
(
const half acc,
MatrixView_half& h_,
const int h_row,
const int h_column, // divisible by 8
MatrixView_q4_column& v_,
const int v_row, // divisible by 8
const int v_column,
const half v_scale,
const uint32_t v_zero, // + 1 (!!)
const int count,
const uint32_t* x_map
)
{
const half* h_ptr = h_.item_ptr(h_row, 0);
const uint32_t* x_map_ptr = x_map + h_column;
const uint32_t* v_ptr = (const uint32_t*) v_.item_uint32_ptr(v_row, v_column);
half result = acc;
for (int i = 0; i < count; i++)
{
uint32_t v_read = *v_ptr; v_ptr += v_.width;
half v_0 = __int2half_rn((int)((v_read ) & 0x0f) - v_zero);
half v_1 = __int2half_rn((int)((v_read >> 4) & 0x0f) - v_zero);
half v_2 = __int2half_rn((int)((v_read >> 8) & 0x0f) - v_zero);
half v_3 = __int2half_rn((int)((v_read >> 12) & 0x0f) - v_zero);
half v_4 = __int2half_rn((int)((v_read >> 16) & 0x0f) - v_zero);
half v_5 = __int2half_rn((int)((v_read >> 20) & 0x0f) - v_zero);
half v_6 = __int2half_rn((int)((v_read >> 24) & 0x0f) - v_zero);
half v_7 = __int2half_rn((int)((v_read >> 28) ) - v_zero);
half tmp = __hmul(h_ptr[*x_map_ptr++], v_0);
tmp = __hfma(h_ptr[*x_map_ptr++], v_1, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_2, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_3, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_4, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_5, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_6, tmp);
tmp = __hfma(h_ptr[*x_map_ptr++], v_7, tmp);
result = __hfma(v_scale, tmp, result);
}
return result;
}
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _tuning_h
#define _tuning_h
struct ExLlamaTuning
{
int matmul_recons_thd;
bool matmul_fused_remap;
bool matmul_no_half2;
};
#endif
// Adapted from turboderp exllama: https://github.com/turboderp/exllama
#ifndef _util_cuh
#define _util_cuh
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cstdint>
#include <cstdio>
#define cudaUnspecified cudaErrorApiFailureBase
// React to failure on return code != cudaSuccess
#define _cuda_check(fn) \
do { \
{_cuda_err = fn;} \
if (_cuda_err != cudaSuccess) goto _cuda_fail; \
} while(false)
// React to failure on return code == 0
#define _alloc_check(fn) \
do { \
if (!(fn)) { _cuda_err = cudaUnspecified; goto _cuda_fail; } \
else _cuda_err = cudaSuccess; \
} while(false)
#endif
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
setup(
name="exllama_kernels",
ext_modules=[
CUDAExtension(
name="exllama_kernels",
sources=[
"exllama_kernels/exllama_ext.cpp",
"exllama_kernels/cuda_buffers.cu",
"exllama_kernels/cuda_func/column_remap.cu",
"exllama_kernels/cuda_func/q4_matmul.cu",
"exllama_kernels/cuda_func/q4_matrix.cu"
],
)
],
cmdclass={"build_ext": BuildExtension},
)
......@@ -383,7 +383,6 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, config, weights):
super().__init__()
self.config = config
process_group = weights.process_group
self.tp_rank = process_group.rank()
......
......@@ -20,7 +20,6 @@ from text_generation_server.utils.layers import (
)
from safetensors import SafetensorError
def load_multi_mqa(
config, prefix: str, weights, bias: bool, head_size, num_heads, hidden_size
):
......@@ -50,6 +49,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
qweight = torch.cat([q_tensor, kv_tensor], dim=1)
qweight = qweight.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.scales")
shape = slice_.get_shape()
......@@ -60,6 +60,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size :]
scales = torch.cat([q_tensor, kv_tensor], dim=1)
scales = scales.to(device=weights.device)
slice_ = weights._get_slice(f"{prefix}.c_attn.qzeros")
shape = slice_.get_shape()
......@@ -70,21 +71,15 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[:, start:stop]
kv_tensor = slice_[:, -2 * head_size * 4 // 32 :]
qzeros = torch.cat([q_tensor, kv_tensor], dim=1)
qzeros = qzeros.to(device=weights.device)
g_idx = weights.get_tensor(f"{prefix}.c_attn.g_idx")
try:
bits = weights.get_tensor("gptq_bits").item()
groupsize = weights.get_tensor("gptq_groupsize").item()
except SafetensorError as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
g_idx = g_idx.to(device=weights.device)
bits, groupsize = weights._get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
from text_generation_server.utils.layers import HAS_EXLLAMA
use_exllama = HAS_EXLLAMA
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
if bias:
slice_ = weights._get_slice(f"{prefix}.c_attn.bias")
......@@ -97,6 +92,7 @@ def _load_multi_mqa_gptq(
q_tensor = slice_[start:stop]
kv_tensor = slice_[-2 * head_size :]
bias = torch.cat([q_tensor, kv_tensor], dim=0)
bias = bias.to(device=weights.device)
return TensorParallelColumnLinear(get_linear(weight, bias, config.quantize))
else:
......@@ -361,7 +357,6 @@ class Block(nn.Module):
max_s,
):
hidden_states, residual = self.ln_1(hidden_states, residual)
hidden_states = self.attn(
hidden_states,
cu_seqlen_prefill,
......
......@@ -6,7 +6,6 @@ import torch.distributed
import numpy as np
from dataclasses import dataclass
from loguru import logger
from opentelemetry import trace
from transformers import PreTrainedTokenizerBase
from typing import Optional, Tuple, List, Type, Union, Dict
......
......@@ -3,14 +3,13 @@ import torch
from abc import ABC, abstractmethod
from typing import List, Tuple, Optional, TypeVar, Type
from transformers import PreTrainedTokenizerBase
from transformers import PreTrainedTokenizerBase, PretrainedConfig
from text_generation_server.models.types import Batch, GeneratedText
from text_generation_server.pb.generate_pb2 import InfoResponse
B = TypeVar("B", bound=Batch)
class Model(ABC):
def __init__(
self,
......
......@@ -16,6 +16,7 @@ from text_generation_server.pb import generate_pb2_grpc, generate_pb2
from text_generation_server.tracing import UDSOpenTelemetryAioServerInterceptor
class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str]):
self.cache = cache
......@@ -140,6 +141,16 @@ def serve(
logger.exception("Error when initializing model")
raise
if quantize == "gptq":
try:
# When using GPTQ, Exllama kernels need some global kernels
# For which we have the finale shapes only after the model has loaded
# This will allocate those buffers.
from text_generation_server.utils.gptq.exllama import create_exllama_buffers
create_exllama_buffers()
except ImportError:
pass
server = aio.server(
interceptors=[
ExceptionInterceptor(),
......
import torch
from exllama_kernels import make_q4, q4_matmul, prepare_buffers, set_tuning_params
# Dummy tensor to pass instead of g_idx since there is no way to pass "None" to a C++ extension
none_tensor = torch.empty((1, 1), device = "meta")
def ext_make_q4(qweight, qzeros, scales, g_idx, device):
"""Construct Q4Matrix, return handle"""
return make_q4(qweight,
qzeros,
scales,
g_idx if g_idx is not None else none_tensor,
device)
def ext_q4_matmul(x, q4, q4_width):
"""Matrix multiplication, returns x @ q4"""
outshape = x.shape[:-1] + (q4_width,)
x = x.view(-1, x.shape[-1])
output = torch.empty((x.shape[0], q4_width), dtype = torch.float16, device = x.device)
q4_matmul(x, q4, output)
return output.view(outshape)
MAX_DQ = 1
MAX_INNER = 1
ACT_ORDER = False
DEVICE = None
TEMP_STATE = None
TEMP_DQ = None
def create_exllama_buffers():
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE, TEMP_STATE, TEMP_DQ
if ACT_ORDER:
# TODO: this should be set to rust side `max_total_tokens`, but TGI
# does not offer an API to expose this variable to python, as this variable
# is handled by the client but it appears the model is initialized by the server.
# An alternative could be to initialize the buffers during warmup.
# Dummy
max_total_tokens = 2048
else:
max_total_tokens = 1
# This temp_state buffer is required to reorder X in the act-order case.
temp_state = torch.zeros((max_total_tokens, MAX_INNER), dtype=torch.float16, device=DEVICE)
temp_dq = torch.zeros((1, MAX_DQ), dtype=torch.float16, device=DEVICE)
# This temp_dq buffer is required to dequantize weights when using cuBLAS, typically for the prefill.
prepare_buffers(DEVICE, temp_state, temp_dq)
matmul_recons_thd = 8
matmul_fused_remap = False
matmul_no_half2 = False
set_tuning_params(matmul_recons_thd, matmul_fused_remap, matmul_no_half2)
TEMP_STATE, TEMP_DQ = temp_state, temp_dq
class Ex4bitLinear:
"""Linear layer implementation with per-group 4-bit quantization of the weights"""
def __init__(self, qweight, qzeros, scales, g_idx, bias, bits, groupsize):
global MAX_DQ, MAX_INNER, ACT_ORDER, DEVICE
assert bits == 4
self.device = qweight.device
self.qweight = qweight
self.qzeros = qzeros
self.scales = scales
self.g_idx = g_idx.cpu() if g_idx is not None else None
self.bias = bias if bias is not None else None
if self.g_idx is not None and ((self.g_idx == 0).all() or torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32))):
self.empty_g_idx = True
self.g_idx = None
assert self.device.type == "cuda"
assert self.device.index is not None
self.q4 = ext_make_q4(
self.qweight,
self.qzeros,
self.scales,
self.g_idx,
self.device.index
)
self.height = qweight.shape[0] * 8
self.width = qweight.shape[1]
# Infer groupsize from height of qzeros
self.groupsize = None
if self.qzeros.shape[0] > 1:
self.groupsize = (self.qweight.shape[0] * 8) // (self.qzeros.shape[0])
if self.groupsize is not None:
assert groupsize == self.groupsize
# Handle act-order matrix
if self.g_idx is not None:
if self.groupsize is None: raise ValueError("Found group index but no groupsize. What do?")
self.act_order = True
else:
self.act_order = False
DEVICE = self.qweight.device
MAX_DQ = max(MAX_DQ, self.qweight.numel() * 8)
if self.act_order:
MAX_INNER = max(MAX_INNER, self.height, self.width)
ACT_ORDER = True
def forward(self, x):
out = ext_q4_matmul(x, self.q4, self.width)
if self.bias is not None:
out.add_(self.bias)
return out
import os
import torch
import torch.distributed
......@@ -16,7 +17,15 @@ except ImportError:
from accelerate import init_empty_weights
from text_generation_server.utils.gptq.quant_linear import QuantLinear
HAS_EXLLAMA = True
if os.getenv("DISABLE_EXLLAMA") == "True":
HAS_EXLLAMA=False
try:
from text_generation_server.utils.gptq.exllama import Ex4bitLinear
except ImportError:
HAS_EXLLAMA = False
from typing import Optional
# Monkey patching
@classmethod
......@@ -144,21 +153,24 @@ def get_linear(weight, bias, quantize):
linear.bias = nn.Parameter(bias)
elif quantize == "gptq":
try:
qweight, qzeros, scales, g_idx, bits, groupsize = weight
qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
except Exception:
raise NotImplementedError(
f"The passed weight is not `gptq` compatible, loader needs to be updated."
)
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,
)
if use_exllama:
linear = Ex4bitLinear(qweight, qzeros, scales, g_idx, bias, bits, groupsize)
else:
linear = QuantLinear(
qweight,
qzeros,
scales,
g_idx,
bias,
bits,
groupsize,
)
else:
raise NotImplementedError(f"Quantization `{quantize}` is not implemented yet.")
return linear
......
from pathlib import Path
from typing import List, Dict, Optional
from typing import List, Dict, Optional, Tuple
from safetensors import safe_open, SafetensorError
import torch
from loguru import logger
class Weights:
......@@ -127,18 +128,8 @@ class Weights:
torch.testing.assert_close(w2, w[0])
g_idx = w[0]
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
bits, groupsize = self._get_gptq_qparams()
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, False)
else:
w = [self.get_sharded(f"{p}.weight", dim=0) for p in prefixes]
weight = torch.cat(w, dim=dim)
......@@ -146,29 +137,74 @@ class Weights:
def get_multi_weights_row(self, prefix: str, quantize: str):
if quantize == "gptq":
use_exllama = True
bits, groupsize = self._get_gptq_qparams()
if bits != 4:
use_exllama = False
if self.process_group.size() > 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
if g_idx is not None:
if not torch.equal(g_idx.cpu(), torch.tensor([i // groupsize for i in range(g_idx.shape[0])], dtype=torch.int32)) and not (g_idx == 0).all():
# Exllama implementation does not support row tensor parallelism with act-order, as
# it would require to reorder input activations that are split unto several GPUs
use_exllama = False
try:
qweight = self.get_sharded(f"{prefix}.qweight", dim=0)
except RuntimeError:
raise RuntimeError(
"Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`"
)
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
weight = (qweight, qzeros, scales, g_idx, bits, groupsize)
raise RuntimeError("Cannot load `gptq` weight, make sure the model is already quantized, or quantize it with `text-generation-server quantize ORIGINAL_MODEL_ID NEW_MODEL_ID`")
from text_generation_server.utils.layers import HAS_EXLLAMA
if use_exllama:
if not HAS_EXLLAMA:
logger.warning("Exllama GPTQ cuda kernels (which are faster) could have been used, but are not currently installed, try using BUILD_EXTENSIONS=True")
use_exllama = False
else:
logger.info("Using exllama kernels")
if use_exllama:
if groupsize >= 0:
# Exllama reorders the weights in advance and the activations on the fly, thus
# the scales and zero-points do not need to be reordered.
qzeros = self.get_sharded(f"{prefix}.qzeros", dim=0)
scales = self.get_sharded(f"{prefix}.scales", dim=0)
else:
raise RuntimeError("Using exllama GPTQ kernel with groupsize<1 is not supported")
# qzeros = self.get_tensor(f"{prefix}.qzeros")
# scales = self.get_tensor(f"{prefix}.scales")
# For tp > 1, at this point we know we do not use act-order
if self.process_group.size() == 1:
g_idx = self.get_tensor(f"{prefix}.g_idx")
else:
g_idx = None
else:
# The triton kernel reorders the scales/zero points instead of the weight/activation.
# Thus, each rank needs the full qzeros/scales.
qzeros = self.get_tensor(f"{prefix}.qzeros")
scales = self.get_tensor(f"{prefix}.scales")
g_idx = self.get_sharded(f"{prefix}.g_idx", dim=0)
weight = (qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama)
else:
weight = self.get_sharded(f"{prefix}.weight", dim=1)
return weight
def _get_gptq_qparams(self) -> Tuple[int, int]:
try:
bits = self.get_tensor("gptq_bits").item()
groupsize = self.get_tensor("gptq_groupsize").item()
except (SafetensorError, RuntimeError) as e:
try:
import os
bits = int(os.getenv("GPTQ_BITS"))
groupsize = int(os.getenv("GPTQ_GROUPSIZE"))
except Exception:
raise e
return bits, groupsize
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment