Commit 386c53bb authored by xuxzh1

update

parent 4667452a
@@ -24,432 +24,427 @@
 * SOFTWARE.
 */

#include "mmvq.cuh"
#include "vecdotq.cuh"

typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs);

static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) {
    return type == GGML_TYPE_Q4_0    ? vec_dot_q4_0_q8_1 :
        type == GGML_TYPE_Q4_1    ? vec_dot_q4_1_q8_1 :
        type == GGML_TYPE_Q5_0    ? vec_dot_q5_0_q8_1 :
        type == GGML_TYPE_Q5_1    ? vec_dot_q5_1_q8_1 :
        type == GGML_TYPE_Q8_0    ? vec_dot_q8_0_q8_1 :
        type == GGML_TYPE_Q2_K    ? vec_dot_q2_K_q8_1 :
        type == GGML_TYPE_Q3_K    ? vec_dot_q3_K_q8_1 :
        type == GGML_TYPE_Q4_K    ? vec_dot_q4_K_q8_1 :
        type == GGML_TYPE_Q5_K    ? vec_dot_q5_K_q8_1 :
        type == GGML_TYPE_Q6_K    ? vec_dot_q6_K_q8_1 :
        type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 :
        type == GGML_TYPE_IQ2_XS  ? vec_dot_iq2_xs_q8_1 :
        type == GGML_TYPE_IQ2_S   ? vec_dot_iq2_s_q8_1 :
        type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 :
        type == GGML_TYPE_IQ1_S   ? vec_dot_iq1_s_q8_1 :
        type == GGML_TYPE_IQ1_M   ? vec_dot_iq1_m_q8_1 :
        type == GGML_TYPE_IQ4_NL  ? vec_dot_iq4_nl_q8_1 :
        type == GGML_TYPE_IQ4_XS  ? vec_dot_iq4_xs_q8_1 :
        type == GGML_TYPE_IQ3_S   ? vec_dot_iq3_s_q8_1 :
        nullptr;
}

static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
    return type == GGML_TYPE_Q4_0    ? VDR_Q4_0_Q8_1_MMVQ :
        type == GGML_TYPE_Q4_1    ? VDR_Q4_1_Q8_1_MMVQ :
        type == GGML_TYPE_Q5_0    ? VDR_Q5_0_Q8_1_MMVQ :
        type == GGML_TYPE_Q5_1    ? VDR_Q5_1_Q8_1_MMVQ :
        type == GGML_TYPE_Q8_0    ? VDR_Q8_0_Q8_1_MMVQ :
        type == GGML_TYPE_Q2_K    ? VDR_Q2_K_Q8_1_MMVQ :
        type == GGML_TYPE_Q3_K    ? VDR_Q3_K_Q8_1_MMVQ :
        type == GGML_TYPE_Q4_K    ? VDR_Q4_K_Q8_1_MMVQ :
        type == GGML_TYPE_Q5_K    ? VDR_Q5_K_Q8_1_MMVQ :
        type == GGML_TYPE_Q6_K    ? VDR_Q6_K_Q8_1_MMVQ :
        type == GGML_TYPE_IQ2_XXS ? VDR_IQ2_XXS_Q8_1_MMVQ :
        type == GGML_TYPE_IQ2_XS  ? VDR_IQ2_XS_Q8_1_MMVQ :
        type == GGML_TYPE_IQ2_S   ? VDR_IQ2_S_Q8_1_MMVQ :
        type == GGML_TYPE_IQ3_XXS ? VDR_IQ3_XXS_Q8_1_MMVQ :
        type == GGML_TYPE_IQ3_S   ? VDR_IQ3_S_Q8_1_MMVQ :
        type == GGML_TYPE_IQ4_NL  ? VDR_IQ4_NL_Q8_1_MMVQ :
        type == GGML_TYPE_IQ4_XS  ? VDR_IQ4_XS_Q8_1_MMVQ :
        1;
}
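For reference (editorial sketch, not part of this commit): because both helpers above are constexpr, every template instantiation of the kernel below bakes in a single dot-product routine and a single vdr value at compile time instead of branching at run time. A minimal illustration using only names already declared above:

// Editorial sketch: for a fixed template argument the chain of ?: collapses to one constant,
// so the per-type values can be checked with static_assert inside device code.
template <ggml_type type>
static __global__ void vdr_demo_kernel(int * out) {
    constexpr int vdr = get_vdr_mmvq(type);
    static_assert(vdr >= 1, "every type maps to at least one value per dot-product step");
    *out = vdr;
}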
template <ggml_type type, int ncols_y>
-#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
-// tell the compiler to use as many registers as it wants, see nwarps definition below
-__launch_bounds__((ncols_y <= 4 ? 4 : 2)*16, 1)
-#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+// #if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
+// tell the compiler to use as many registers as it wants, see nwarps definition below
+//__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
+// #endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void mul_mat_vec_q(
    const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {

    constexpr int qk  = ggml_cuda_type_traits<type>::qk;
    constexpr int qi  = ggml_cuda_type_traits<type>::qi;
    constexpr int vdr = get_vdr_mmvq(type);

    constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);

#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
    constexpr int nwarps              = 1;
    constexpr int rows_per_cuda_block = 1;
#else
    constexpr int nwarps              = ncols_y <= 4 ? 4 : 2;
-    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
+    constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 1;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)

    const int tid  = WARP_SIZE*threadIdx.y + threadIdx.x;
    const int row0 = rows_per_cuda_block*blockIdx.x;
    const int blocks_per_row_x = ncols_x / qk;
    const int blocks_per_col_y = nrows_y / QK8_1;
    constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;

    // partial sum for each thread
    float tmp[ncols_y][rows_per_cuda_block] = {0.0f};

    const block_q8_1 * y = (const block_q8_1 *) vy;

    for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
        const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx

        // x block quant index when casting the quants to int
        const int kqs = vdr * (tid % (qi/vdr));

#pragma unroll ncols_y
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll rows_per_cuda_block
            for (int i = 0; i < rows_per_cuda_block; ++i) {
-                //tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
-                atomicAdd(&tmp[j][i], vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs));
+                tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs);
            }
        }
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
    if (threadIdx.y > 0) {
#pragma unroll ncols_y
        for (int j = 0; j < ncols_y; ++j) {
#pragma unroll rows_per_cuda_block
            for (int i = 0; i < rows_per_cuda_block; ++i) {
-                //tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
-                atomicExch(&tmp_shared[threadIdx.y-1][j][i][threadIdx.x], tmp[j][i]);
+                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
            }
        }
    }
    __syncthreads();
    if (threadIdx.y > 0) {
        return;
    }

    // sum up partial sums and write back result
#pragma unroll ncols_y
    for (int j = 0; j < ncols_y; ++j) {
#pragma unroll rows_per_cuda_block
        for (int i = 0; i < rows_per_cuda_block; ++i) {
#pragma unroll
            for (int l = 0; l < nwarps-1; ++l) {
-                //tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
-                atomicAdd(&tmp[j][i], tmp_shared[l][j][i][threadIdx.x]);
+                tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            }
-            //tmp[j][i] = warp_reduce_sum(tmp[j][i]);
-            atomicExch(&tmp[j][i], warp_reduce_sum(tmp[j][i]));
+            tmp[j][i] = warp_reduce_sum(tmp[j][i]);
        }

        if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
-            //dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
-            atomicExch(&dst[j*nrows_dst + row0 + threadIdx.x], tmp[j][threadIdx.x]);
+            dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
        }
    }
}
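For reference (editorial sketch, not part of this commit): each thread block covers rows_per_cuda_block rows of the quantized weight matrix; its nwarps*WARP_SIZE threads stride over the quantized blocks of those rows with qi/vdr threads cooperating on each x block, and warp 0 then folds the per-warp partial sums together through tmp_shared and warp_reduce_sum. The same index arithmetic can be mirrored on the host, with the type constants passed in as plain parameters rather than taken from ggml_cuda_type_traits:

// Editorial sketch of the kernel's work distribution (hypothetical parameters).
// qk: weights per quantized x block, qi: ints per x block, vdr: ints consumed per dot call.
#include <cstdio>
static void mmvq_mapping(int warp_size, int nwarps, int qk, int qi, int vdr, int ncols_x) {
    const int blocks_per_row_x    = ncols_x / qk;                  // x blocks in one weight row
    const int threads_per_x_block = qi / vdr;                      // threads cooperating on one x block
    const int blocks_per_iter     = vdr * nwarps * warp_size / qi; // x blocks covered per loop iteration
    printf("%d x blocks per row, %d threads per x block, %d x blocks per iteration\n",
           blocks_per_row_x, threads_per_x_block, blocks_per_iter);
    // thread t starts at x block t / (qi/vdr) and advances by blocks_per_iter,
    // exactly like the kbx loop in mul_mat_vec_q.
}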
template <ggml_type type>
static void mul_mat_vec_q_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);

    int id = ggml_cuda_get_device();

    int64_t nwarps = 1;
    int64_t rows_per_cuda_block = 1;

    if (ggml_cuda_info().devices[id].cc < 1001030) { // NVIDIA and AMD older than RDNA2 but not CDNA
        switch(ncols_y) {
            case 1:
                nwarps = 4;
                rows_per_cuda_block = 1;
                break;
            case 2:
            case 3:
            case 4:
                nwarps = 4;
-                rows_per_cuda_block = 2;
+                rows_per_cuda_block = 1;
                break;
            case 5:
            case 6:
            case 7:
            case 8:
                nwarps = 2;
                rows_per_cuda_block = 2;
                break;
            default:
                GGML_ABORT("fatal error");
                break;
        }
    }
    const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
    const dim3 block_nums(nblocks, 1, 1);
    const dim3 block_dims(WARP_SIZE, nwarps, 1);

    switch (ncols_y) {
        case 1:
            mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 2:
            mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 3:
            mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 4:
            mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 5:
            mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 6:
            mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 7:
            mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        case 8:
            mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
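For reference (editorial sketch with hypothetical sizes, not part of this commit): with ncols_y == 1 the launcher above starts one thread block per weight-matrix row, each block WARP_SIZE x nwarps threads:

// Editorial sketch: the same launch arithmetic, evaluated for a hypothetical 4096-row
// weight matrix with ncols_y == 1 and a 32-wide warp assumed on the target.
static void launch_geometry_example() {
    const long long nrows_x = 4096, rows_per_cuda_block = 1, nwarps = 4, warp_size = 32;
    const long long nblocks           = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block; // 4096
    const long long threads_per_block = warp_size * nwarps;                                        // 128
    // grid (nblocks, 1, 1), block (warp_size, nwarps, 1): one thread block per output row
    (void) nblocks; (void) threads_per_block;
}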
static void mul_mat_vec_q4_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_1_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q8_0_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q2_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q3_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q4_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q5_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_q6_K_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq2_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_xxs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq1_m_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_nl_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq4_xs_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}

static void mul_mat_vec_iq3_s_q8_1_cuda(
    const void * vx, const void * vy, float * dst,
    const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) {

    mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
}
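Each wrapper above is a thin shim that only pins the template parameter, so the run-time dispatch in ggml_cuda_op_mul_mat_vec_q below can stay a plain switch over ggml_type. Editorial illustration only:

// Editorial note (not part of the commit): the wrapper call
//     mul_mat_vec_q4_0_q8_1_cuda(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);
// is interchangeable with instantiating the template directly:
//     mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream);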
void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream) {

    const int64_t ne00 = src0->ne[0];
    const int64_t row_diff = row_high - row_low;

    const int64_t ne10 = src1->ne[0];
    GGML_ASSERT(ne10 % QK8_1 == 0);

    const int64_t ne0 = dst->ne[0];

    int id = ggml_cuda_get_device();

    // the main device has a larger memory buffer to hold the results from all GPUs
    // nrows_dst == nrows of the matrix that the kernel writes into
    const int64_t nrows_dst = id == ctx.device ? ne0 : row_diff;

    switch (src0->type) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q4_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q4_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q5_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q5_1_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q8_0_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q2_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q3_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q4_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q5_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q6_K_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_iq2_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_iq2_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_iq2_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_iq3_xxs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_iq1_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_iq1_m_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_iq4_nl_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_iq4_xs_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_iq3_s_q8_1_cuda(src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, src1_padded_row_size, src1_ncols, nrows_dst, stream);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }

    GGML_UNUSED(src1);
    GGML_UNUSED(dst);
    GGML_UNUSED(src1_ddf_i);
    GGML_UNUSED(src1_ncols);
    GGML_UNUSED(src1_padded_row_size);
}
\ No newline at end of file
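For reference (editorial sketch, hypothetical sizes, not part of this commit): the kernels above receive src1 already quantized to q8_1, which is why ne10 must be a multiple of QK8_1 and why src1_padded_row_size is what gets passed as the kernels' nrows_y:

// Editorial sketch: how many q8_1 blocks one quantized src1 row occupies.
// QK8_1 == 32 in current ggml is treated as an assumption here and passed in explicitly.
static long long q8_1_blocks_per_row(long long padded_row_size, long long qk8_1 = 32) {
    return padded_row_size / qk8_1;   // e.g. 4096 / 32 == 128, the kernel's blocks_per_col_y
}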
@@ -24,173 +24,174 @@
 * SOFTWARE.
 */

#include "quantize.cuh"
#include <cstdint>

static __global__ __launch_bounds__(1024) void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) {
    const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (ix0 >= kx0_padded) {
        return;
    }

    const int64_t ix1 = blockIdx.y;

    const int64_t i_padded = ix1*kx0_padded + ix0;

    block_q8_1 * y = (block_q8_1 *) vy;

    const int64_t ib  = i_padded / QK8_1; // block index
    const int64_t iqs = i_padded % QK8_1; // quant index

    const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f;
    float amax = fabsf(xi);
    float sum  = xi;

    amax = warp_reduce_max(amax);
    sum  = warp_reduce_sum(sum);

    const float  d = amax / 127;
    const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);

    y[ib].qs[iqs] = q;

    if (iqs > 0) {
        return;
    }

-    ggml_half2 ds = {d, sum};
-    y[ib].ds = ds;
+    y[ib].ds = ggml_half2 {d, sum};
    //reinterpret_cast<half&>(y[ib].ds) = ds;
    //reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
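For reference (editorial sketch, not part of this commit): per block of QK8_1 values the kernel stores the int8 quants plus a half2 ds holding the scale d and the block sum. The same rounding rule on the host, assuming QK8_1 == 32 as in current ggml:

// Editorial sketch: host-side mirror of quantize_q8_1 for one 32-value block.
#include <cmath>
#include <cstdint>
static void quantize_block_q8_1_ref(const float * x, int8_t * q, float * d_out, float * sum_out) {
    float amax = 0.0f, sum = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = fmaxf(amax, fabsf(x[i]));
        sum += x[i];
    }
    const float d = amax / 127.0f;                           // scale: max magnitude maps to +-127
    for (int i = 0; i < 32; ++i) {
        q[i] = amax == 0.0f ? 0 : (int8_t) roundf(x[i] / d); // same guard against amax == 0 as the kernel
    }
    *d_out   = d;    // stored as the first half of ds
    *sum_out = sum;  // stored as the second half of ds
}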
template <mmq_q8_1_ds_layout ds_layout>
static __global__ void quantize_mmq_q8_1(
    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {

    constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
    constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;

    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;

    if (ix0 >= kx0_padded) {
        return;
    }

    const float4 * x4 = (const float4 *) x;

    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;

    block_q8_1_mmq * y = (block_q8_1_mmq *) vy;

    const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block

    // Load 4 floats per thread and calculate max. abs. value between them:
    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
    float amax = fabsf(xi.x);
    amax = fmaxf(amax, fabsf(xi.y));
    amax = fmaxf(amax, fabsf(xi.z));
    amax = fmaxf(amax, fabsf(xi.w));

    // Exchange max. abs. value between vals_per_scale/4 threads.
#pragma unroll
    for (int offset = vals_per_scale/8; offset > 0; offset >>= 1) {
        amax = fmaxf(amax, __shfl_xor_sync(0xFFFFFFFF, amax, offset, WARP_SIZE));
    }

    float sum;
    if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
        sum = xi.x + xi.y + xi.z + xi.w;

        // Exchange calculate sum across vals_per_sum/4 threads.
#pragma unroll
        for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
            sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
        }
    }

    const float d_inv = 127.0f / amax;
    char4 q;
    q.x = roundf(xi.x*d_inv);
    q.y = roundf(xi.y*d_inv);
    q.z = roundf(xi.z*d_inv);
    q.w = roundf(xi.w*d_inv);

    // Write back 4 int8 values as a single 32 bit value for better memory bandwidth:
    char4 * yqs4 = (char4 *) y[ib].qs;
    yqs4[iqs/4] = q;

    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6) {
        if (iqs % 16 != 0 || iqs >= 96) {
            return;
        }

        y[ib].d2s6[2 + iqs/16] = sum;

        if (iqs % 64 != 0) {
            return;
        }

        const float d = 1.0f / d_inv;

        y[ib].d2s6[iqs/64] = d;

        return;
    }

    if (iqs % 32 != 0) {
        return;
    }

    const float d = 1.0f / d_inv;

    if (ds_layout == MMQ_Q8_1_DS_LAYOUT_DS4) {
        y[ib].ds4[iqs/32] = make_half2(d, sum);
    } else {
        y[ib].d4[iqs/32] = d;
    }
}
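For reference (editorial sketch, not part of this commit): depending on the layout, each block_q8_1_mmq stores either half2 (scale, sum) pairs, plain scales, or the D2S6 packing of two scales plus six partial sums; the iqs/16, iqs/32, iqs/64 arithmetic above selects the slot. The loop below just replays that index math for one 128-value block:

// Editorial sketch: D2S6 slot assignment implied by the writes above
// (iqs is a multiple of 4 per thread; only iqs % 16 == 0 writes a sum).
#include <cstdio>
static void d2s6_slots_demo() {
    for (int iqs = 0; iqs < 128; iqs += 16) {
        if (iqs < 96)      printf("iqs=%3d -> sum slot d2s6[%d]\n",   iqs, 2 + iqs/16); // slots 2..7
        if (iqs % 64 == 0) printf("iqs=%3d -> scale slot d2s6[%d]\n", iqs, iqs/64);     // slots 0..1
    }
}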
void quantize_row_q8_1_cuda(
    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {

    GGML_ASSERT(kx0_padded % QK8_1 == 0);

    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
    const dim3 num_blocks(block_num_x, kx1*channels, 1);
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);

    GGML_UNUSED(type_x);
}

void quantize_mmq_q8_1_cuda(
    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {

    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);

    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
    const dim3 num_blocks(block_num_x, kx1, channels);
    const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
    switch (mmq_get_q8_1_ds_layout(type_x)) {
        case MMQ_Q8_1_DS_LAYOUT_D4:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
            break;
        case MMQ_Q8_1_DS_LAYOUT_DS4:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
            break;
        case MMQ_Q8_1_DS_LAYOUT_D2S6:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
            break;
        default:
            GGML_ABORT("fatal error");
            break;
    }
}
\ No newline at end of file
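For reference (editorial sketch, not part of this commit): both launchers tile the padded row length on the grid's x axis, with rows and channels on the y/z axes; each quantize_mmq_q8_1 thread covers four floats, hence the 4* factors. Written with the block size as a parameter so no ggml constant is assumed:

// Editorial sketch: grid sizing used by quantize_mmq_q8_1_cuda.
static long long mmq_blocks_per_row(long long kx0_padded, long long block_size_mmq) {
    return (kx0_padded + 4*block_size_mmq - 1) / (4*block_size_mmq); // each thread quantizes 4 floats
}
// e.g. mmq_blocks_per_row(4096, 128) == 8, giving a grid of (8, kx1, channels)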
@@ -505,7 +505,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq(
    float sumf = 0.0f;

-#pragma unroll
+#pragma unroll QR6_K
    for (int i = 0; i < QR6_K; ++i) {
        const int sc = scales[4*i];

@@ -803,7 +803,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
    int u[QR6_K];
    float d8[QR6_K];

-#pragma unroll
+#pragma unroll QR6_K
    for (int i = 0; i < QR6_K; ++i) {
        u[i] = get_int_b4(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
        d8[i] = __low2float(bq8_1[bq8_offset + 2*i].ds);
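For reference (editorial note, not part of this commit): a bare #pragma unroll asks the compiler to fully unroll a loop whose trip count it can see at compile time, while #pragma unroll N requests an explicit unroll factor, so spelling out QR6_K here mainly makes the request explicit. A self-contained sketch:

// Editorial sketch: both forms unroll a fixed-trip-count loop; the explicit count
// states the factor instead of leaving it to the compiler.
__global__ void unroll_demo(float * out, const float * in) {
    float acc = 0.0f;
#pragma unroll 4
    for (int i = 0; i < 4; ++i) {   // same effect as a bare #pragma unroll for this loop
        acc += in[4*threadIdx.x + i];
    }
    out[threadIdx.x] = acc;
}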