Unverified Commit 90969fb3 authored by LukasBluebaum's avatar LukasBluebaum Committed by GitHub
Browse files

[Kernel] Add more dtype support for GGUF dequantization (#15879)


Signed-off-by: default avatarlukas.bluebaum <lukas.bluebaum@aleph-alpha.com>
parent 101f1481
...@@ -145,7 +145,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); ...@@ -145,7 +145,8 @@ torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm);
#endif #endif
torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m,
int64_t n); int64_t n,
std::optional<at::ScalarType> const& dtype);
torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X,
int64_t type, int64_t row); int64_t type, int64_t row);
......
...@@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ ...@@ -94,8 +94,8 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
dfloat2 v; dfloat2 v;
dequantize_kernel(vx, ib, iqs, v); dequantize_kernel(vx, ib, iqs, v);
y[iybs + iqs + 0] = v.x; y[iybs + iqs + 0] = convert_from_half<dst_t>(v.x);
y[iybs + iqs + y_offset] = v.y; y[iybs + iqs + y_offset] = convert_from_half<dst_t>(v.y);
} }
template<typename dst_t> template<typename dst_t>
...@@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t ...@@ -114,10 +114,10 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, dst_t
half dall = __low2half(x[i].dm); half dall = __low2half(x[i].dm);
half dmin = __high2half(x[i].dm); half dmin = __high2half(x[i].dm);
y[l+ 0] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))); y[l+ 0] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+0] & 0xF) * ((q >> 0) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+0] >> 4))));
y[l+32] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))); y[l+32] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+2] & 0xF) * ((q >> 2) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+2] >> 4))));
y[l+64] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))); y[l+64] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+4] & 0xF) * ((q >> 4) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+4] >> 4))));
y[l+96] = __hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))); y[l+96] = convert_from_half<dst_t>(__hsub(__hmul(dall, __int2half_rn((x[i].scales[is+6] & 0xF) * ((q >> 6) & 3))), __hmul(dmin, __int2half_rn(x[i].scales[is+6] >> 4))));
} }
template<typename dst_t> template<typename dst_t>
...@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t ...@@ -148,7 +148,9 @@ static __global__ void dequantize_block_q3_K(const void * __restrict__ vx, dst_t
const uint8_t * q = x[i].qs + 32*n; const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask; const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = __hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))); for (int l = l0; l < l0+4; ++l) {
y[l] = convert_from_half<dst_t>(__hmul(dl, __int2half_rn((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4))));
}
} }
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
...@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t ...@@ -188,8 +190,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, dst_t
const half d2 = __hmul(dall, __int2half_rn(sc)); const half d2 = __hmul(dall, __int2half_rn(sc));
const half m2 = __hmul(dmin, __int2half_rn(m)); const half m2 = __hmul(dmin, __int2half_rn(m));
for (int l = 0; l < n; ++l) { for (int l = 0; l < n; ++l) {
y[l + 0] = __hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1); y[l + 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn(q[l] & 0xF)), m1));
y[l +32] = __hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2); y[l +32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn(q[l] >> 4)), m2));
} }
} }
...@@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t ...@@ -220,11 +222,11 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, dst_t
const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m)); const half d2 = __hmul(dall, __int2half_rn(sc)); const half m2 = __hmul(dmin, __int2half_rn(m));
uint8_t hm = 1 << (2*il); uint8_t hm = 1 << (2*il);
y[ 0] = __hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1); y[ 0] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0))), m1));
y[ 1] = __hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1); y[ 1] = convert_from_half<dst_t>(__hsub(__hmul(d1, __int2half_rn((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0))), m1));
hm <<= 1; hm <<= 1;
y[32] = __hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2); y[32] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[0] >> 4) + (qh[0] & hm ? 16 : 0))), m2));
y[33] = __hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2); y[33] = convert_from_half<dst_t>(__hsub(__hmul(d2, __int2half_rn((ql[1] >> 4) + (qh[1] & hm ? 16 : 0))), m2));
} }
template<typename dst_t> template<typename dst_t>
...@@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t ...@@ -247,10 +249,10 @@ static __global__ void dequantize_block_q6_K(const void * __restrict__ vx, dst_t
const uint8_t qh = x[i].qh[32*ip + il]; const uint8_t qh = x[i].qh[32*ip + il];
const int8_t * sc = x[i].scales + is; const int8_t * sc = x[i].scales + is;
y[ 0] = __hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); y[ 0] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))));
y[32] = __hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); y[32] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))));
y[64] = __hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); y[64] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))));
y[96] = __hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); y[96] = convert_from_half<dst_t>(__hmul(d, __int2half_rn(sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))));
} }
template<typename dst_t> template<typename dst_t>
...@@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds ...@@ -269,7 +271,7 @@ static __global__ void dequantize_block_iq2_xxs(const void * __restrict__ vx, ds
const uint32_t aux32 = q2[2] | (q2[3] << 16); const uint32_t aux32 = q2[2] | (q2[3] << 16);
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f; const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f)); for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
} }
template<typename dst_t> template<typename dst_t>
...@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst ...@@ -286,7 +288,7 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511)); const uint8_t * grid = (const uint8_t *)(iq2xs_grid + (q2[il] & 511));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = ksigns_iq2xs[q2[il] >> 9]; const uint8_t signs = ksigns_iq2xs[q2[il] >> 9];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f)); for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
} }
...@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_ ...@@ -303,7 +305,7 @@ static __global__ void dequantize_block_iq2_s(const void * __restrict__ vx, dst_
const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300))); const uint8_t * grid = (const uint8_t *)(iq2s_grid + (x[i].qs[4*ib+il] | ((x[i].qh[ib] << (8-2*il)) & 0x300)));
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f; const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib] >> 4*(il/2)) & 0xf)) * 0.25f;
const uint8_t signs = x[i].qs[QK_K/8+4*ib+il]; const uint8_t signs = x[i].qs[QK_K/8+4*ib+il];
for (int j = 0; j < 8; ++j) y[j] = __float2half(d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f)); for (int j = 0; j < 8; ++j) y[j] = d * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
} }
template<typename dst_t> template<typename dst_t>
...@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds ...@@ -324,8 +326,8 @@ static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, ds
const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f; const float d = __half2float(x[i].d) * (0.5f + (aux32 >> 28)) * 0.5f;
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127]; const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f)); y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f)); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
} }
} }
...@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_ ...@@ -345,8 +347,8 @@ static __global__ void dequantize_block_iq3_s(const void * __restrict__ vx, dst_
const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f; const float d = __half2float(x[i].d) * (0.5f + ((x[i].scales[ib/2] >> 4*(ib%2)) & 0xf)) * 0.5f;
const uint8_t signs = x[i].signs[4*ib + il]; const uint8_t signs = x[i].signs[4*ib + il];
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
y[j+0] = __float2half(d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f)); y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = __float2half(d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f)); y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
} }
} }
...@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_ ...@@ -367,7 +369,7 @@ static __global__ void dequantize_block_iq1_s(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f; grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) { for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta)); y[j] = d * (q[j] + delta);
} }
} }
...@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_ ...@@ -392,7 +394,7 @@ static __global__ void dequantize_block_iq1_m(const void * __restrict__ vx, dst_
grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f; grid32[1] = (grid32[0] >> 4) & 0x0f0f0f0f;
grid32[0] &= 0x0f0f0f0f; grid32[0] &= 0x0f0f0f0f;
for (int j = 0; j < 8; ++j) { for (int j = 0; j < 8; ++j) {
y[j] = __float2half(d * (q[j] + delta)); y[j] = d * (q[j] + delta);
} }
} }
...@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst ...@@ -409,8 +411,8 @@ static __global__ void dequantize_block_iq4_nl(const void * __restrict__ vx, dst
const uint8_t * q4 = x[ib].qs + 4*il; const uint8_t * q4 = x[ib].qs + 4*il;
const float d = __half2float(x[ib].d); const float d = __half2float(x[ib].d);
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]); y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]); y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
} }
} }
...@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst ...@@ -427,8 +429,8 @@ static __global__ void dequantize_block_iq4_xs(const void * __restrict__ vx, dst
const uint8_t * q4 = x[i].qs + 16*ib + 4*il; const uint8_t * q4 = x[i].qs + 16*ib + 4*il;
const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32); const float d = __half2float(x[i].d) * ((((x[i].scales_l[ib/2] >> 4*(ib%2)) & 0xf) | (((x[i].scales_h >> 2*ib) & 3) << 4)) - 32);
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
y[j+ 0] = __float2half(d * kvalues_iq4nl[q4[j] & 0xf]); y[j+ 0] = d * kvalues_iq4nl[q4[j] & 0xf];
y[j+16] = __float2half(d * kvalues_iq4nl[q4[j] >> 4]); y[j+16] = d * kvalues_iq4nl[q4[j] >> 4];
} }
} }
...@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k, ...@@ -522,7 +524,8 @@ static void dequantize_row_iq4_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y); dequantize_block_iq4_xs<<<nb, 32, 0, stream>>>(vx, y);
} }
static to_fp16_cuda_t ggml_get_to_fp16_cuda(int64_t type) { template<typename dst_t>
static to_cuda_ggml_t<dst_t> ggml_get_to_cuda(int64_t type) {
switch (type) { switch (type) {
case 2: case 2:
return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>; return dequantize_block_cuda<QK4_0, QR4_0, dequantize_q4_0>;
......
...@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, - ...@@ -1063,7 +1063,8 @@ static const __device__ int8_t kvalues_iq4nl[16] = {-127, -104, -83, -65, -49, -
typedef half dfloat; // dequantize float typedef half dfloat; // dequantize float
typedef half2 dfloat2; typedef half2 dfloat2;
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v); typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
typedef void (*to_fp16_cuda_t)(const void * __restrict__ x, dfloat * __restrict__ y, int k, cudaStream_t stream); template<typename dst_t>
using to_cuda_ggml_t = void (*)(const void * __restrict__ x, dst_t * __restrict__ y, int k, cudaStream_t stream);
typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs);
typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc); typedef void (*allocate_tiles_cuda_t)(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc);
typedef void (*load_tiles_cuda_t)( typedef void (*load_tiles_cuda_t)(
...@@ -1075,6 +1076,20 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)( ...@@ -1075,6 +1076,20 @@ typedef float (*vec_dot_q_mul_mat_cuda_t)(
// Utility function // Utility function
template<typename dst_t>
static __device__ __forceinline__ dst_t convert_from_half(half val) {
return val;
}
template<>
__device__ __forceinline__ c10::BFloat16 convert_from_half<c10::BFloat16>(half val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return __float2bfloat16(__half2float(val));
#else
return __half2float(val);
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
}
#if defined(USE_ROCM) #if defined(USE_ROCM)
#ifndef __has_builtin #ifndef __has_builtin
......
...@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx, ...@@ -71,14 +71,19 @@ static void quantize_row_q8_1_cuda(const scalar_t* x, void* vy, const int kx,
} }
torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight torch::Tensor ggml_dequantize(torch::Tensor W, // quant weight
int64_t type, int64_t m, int64_t n) { int64_t type, int64_t m, int64_t n,
std::optional<at::ScalarType> const& dtype) {
const at::cuda::OptionalCUDAGuard device_guard(device_of(W)); const at::cuda::OptionalCUDAGuard device_guard(device_of(W));
auto options = auto dtype_ = dtype.value_or(torch::kFloat16);
torch::TensorOptions().dtype(torch::kFloat16).device(W.device()); auto options = torch::TensorOptions().dtype(dtype_).device(W.device());
at::Tensor DW = torch::empty({m, n}, options); at::Tensor DW = torch::empty({m, n}, options);
cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream(); cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(type);
to_fp16_cuda((void*)W.data_ptr(), (half*)DW.data_ptr(), m * n, stream); VLLM_DISPATCH_FLOATING_TYPES(DW.scalar_type(), "ggml_dequantize", [&] {
auto to_cuda = ggml_get_to_cuda<scalar_t>(type);
to_cuda((void*)W.data_ptr(), (scalar_t*)DW.data_ptr(), m * n, stream);
});
return DW; return DW;
} }
......
...@@ -295,7 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ...@@ -295,7 +295,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
#endif #endif
// Dequantization for GGML. // Dequantization for GGML.
ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"); ops.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize); ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);
// mmvq kernel for GGML. // mmvq kernel for GGML.
......
...@@ -15,7 +15,8 @@ def test_ggml_opcheck(quant_type): ...@@ -15,7 +15,8 @@ def test_ggml_opcheck(quant_type):
qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8) qweight = torch.randint(0, 100, shape, device='cuda', dtype=torch.uint8)
m = qweight.shape[0] m = qweight.shape[0]
n = qweight.shape[1] // type_size * block_size n = qweight.shape[1] // type_size * block_size
opcheck(torch.ops._C.ggml_dequantize, (qweight, quant_type, m, n)) opcheck(torch.ops._C.ggml_dequantize,
(qweight, quant_type, m, n, torch.float16))
x = torch.rand((m, 512), device='cuda', dtype=torch.float16) x = torch.rand((m, 512), device='cuda', dtype=torch.float16)
opcheck(torch.ops._C.ggml_mul_mat_a8, opcheck(torch.ops._C.ggml_mul_mat_a8,
......
...@@ -65,7 +65,7 @@ QUANT_TYPES = [ ...@@ -65,7 +65,7 @@ QUANT_TYPES = [
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES) @pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("dtype", [torch.half]) @pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("quant_type", QUANT_TYPES) @pytest.mark.parametrize("quant_type", QUANT_TYPES)
@torch.inference_mode() @torch.inference_mode()
def test_dequantize(hidden_size: int, dtype: torch.dtype, def test_dequantize(hidden_size: int, dtype: torch.dtype,
...@@ -78,7 +78,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype, ...@@ -78,7 +78,7 @@ def test_dequantize(hidden_size: int, dtype: torch.dtype,
ref_output = torch.tensor(dequantize(tensor.data, quant_type), ref_output = torch.tensor(dequantize(tensor.data, quant_type),
device="cuda").to(dtype) device="cuda").to(dtype)
output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"), output = ops.ggml_dequantize(torch.tensor(tensor.data, device="cuda"),
quant_type, *list(shape)).to(dtype) quant_type, *list(shape), dtype)
torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2) torch.testing.assert_close(output, ref_output, atol=1e-2, rtol=4e-2)
......
...@@ -436,9 +436,12 @@ if hasattr(torch.ops._C, "allspark_w8a16_gemm"): ...@@ -436,9 +436,12 @@ if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
if hasattr(torch.ops._C, "ggml_dequantize"): if hasattr(torch.ops._C, "ggml_dequantize"):
@register_fake("_C::ggml_dequantize") @register_fake("_C::ggml_dequantize")
def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, def _ggml_dequantize_fake(
m: torch.SymInt, W: torch.Tensor,
n: torch.SymInt) -> torch.Tensor: quant_type: int,
m: torch.SymInt,
n: torch.SymInt,
dtype: Optional[torch.dtype] = None) -> torch.Tensor:
return torch.empty((m, n), dtype=torch.float16, device=W.device) return torch.empty((m, n), dtype=torch.float16, device=W.device)
@register_fake("_C::ggml_mul_mat_vec_a8") @register_fake("_C::ggml_mul_mat_vec_a8")
...@@ -1097,9 +1100,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -1097,9 +1100,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf # gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int,
n: int) -> torch.Tensor: dtype: Optional[torch.dtype]) -> torch.Tensor:
return torch.ops._C.ggml_dequantize(W, quant_type, m, n) return torch.ops._C.ggml_dequantize(W, quant_type, m, n, dtype)
def ggml_mul_mat_vec_a8( def ggml_mul_mat_vec_a8(
......
...@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor, ...@@ -117,7 +117,7 @@ def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
elif qweight_type in DEQUANT_TYPES: elif qweight_type in DEQUANT_TYPES:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type] block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size) shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape) weight = ops.ggml_dequantize(qweight, qweight_type, *shape, x.dtype)
y = x @ weight.T y = x @ weight.T
else: else:
# Raise an error if the quantization type is not supported. # Raise an error if the quantization type is not supported.
...@@ -377,7 +377,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod): ...@@ -377,7 +377,7 @@ class GGUFEmbeddingMethod(GGUFLinearMethod):
x_flat = x.flatten() x_flat = x.flatten()
quant = torch.index_select(qweight, dim=0, index=x_flat) quant = torch.index_select(qweight, dim=0, index=x_flat)
dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size, dequant = ops.ggml_dequantize(quant, qweight_type, hidden_size,
x_flat.shape[0]).to(self.params_dtype) x_flat.shape[0], self.params_dtype)
return dequant.view(*x.shape, hidden_size) return dequant.view(*x.shape, hidden_size)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment