Unverified Commit 0cf7794b authored by Daniel Hiltgen, committed by GitHub

ggml update to b7108 (#12992)

* Revert "vulkan: temporary carry of vulkan fixes (#12971)"

This reverts commit 3a9e8e9f.

* ggml update to b7087

* fix argsort on metal

* update to b7108

* fix bakllava regression

This model lacks the metadata for the projector type.

* update to b7209

* fix TopK perf

* only build arm code on arm
parent 854d40ed
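
The central change below is that the CUDA mat-vec kernels (mmvf.cu for F32/F16/BF16 and mmvq.cu for quantized types) now take a fusion-argument struct, so an optional bias add and a gated-linear-unit (GLU) activation can run in the kernel epilogue instead of as separate kernels. The struct definitions are not part of this excerpt; the following is only a sketch reconstructed from how the fields are used in the diff, assuming the real definitions live in common.cuh:

// Hypothetical sketch only; field set inferred from the usage below.
struct ggml_cuda_mm_fusion_args_host {
    const ggml_tensor * x_bias    = nullptr; // optional bias added to the mat-vec result
    const ggml_tensor * gate      = nullptr; // optional gate matrix (same type/strides as src0)
    const ggml_tensor * gate_bias = nullptr; // optional bias added to the gate product
    ggml_glu_op         glu_op    = GGML_GLU_OP_SWIGLU; // which GLU variant to apply
};

struct ggml_cuda_mm_fusion_args_device {
    const void * x_bias    = nullptr; // raw device pointers extracted from the host tensors
    const void * gate      = nullptr;
    const void * gate_bias = nullptr;
    ggml_glu_op  glu_op    = GGML_GLU_OP_SWIGLU;
};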
#include "ggml.h" #include "ggml.h"
#include "common.cuh" #include "common.cuh"
#include "convert.cuh" #include "unary.cuh"
#include "mmvf.cuh" #include "mmvf.cuh"
#include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size> template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false>
static __global__ void mul_mat_vec_f( static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst, const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
@@ -24,58 +25,164 @@ static __global__ void mul_mat_vec_f(
    y   += int64_t(sample_y)  *stride_sample_y   + channel_y  *stride_channel_y;
    dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;

    bool use_gate      = false;
    bool use_bias      = false;
    bool use_gate_bias = false;
    ggml_glu_op glu_op = ggml_glu_op::GGML_GLU_OP_SWIGLU;
    const T     * gate_x    = nullptr;
    const float * x_bias    = nullptr;
    const float * gate_bias = nullptr;

    if constexpr (has_fusion) {
        use_gate      = fusion.gate      != nullptr;
        use_bias      = fusion.x_bias    != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr;
        glu_op        = fusion.glu_op;
        if (use_gate) {
            gate_x = static_cast<const T *>(fusion.gate);
        }
        if (use_bias) {
            x_bias = static_cast<const float *>(fusion.x_bias);
        }
        if (use_gate_bias) {
            gate_bias = static_cast<const float *>(fusion.gate_bias);
            use_gate_bias = use_gate;
        } else {
            use_gate_bias = false;
        }
    }

    if (use_gate) {
        gate_x += int64_t(sample_x)*stride_sample_x + channel_x*stride_channel_x + row*stride_row;
    }
    if constexpr (has_fusion) {
        const int channel_bias = ids ? channel_x : channel_dst;
        if (use_bias) {
            x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
        if (use_gate_bias) {
            gate_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
        }
    }
    const float2 * y2 = (const float2 *) y;

    extern __shared__ char data_mmv[];
    float * buf_iw = (float *) data_mmv;
    float * buf_iw_gate = nullptr;
    if constexpr (has_fusion) {
        buf_iw_gate = (float *) (data_mmv + warp_size*sizeof(float));
    }

    if (block_size > warp_size) {
        if (tid < warp_size) {
            buf_iw[tid] = 0.0f;
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid] = 0.0f;
                }
            }
        }
        __syncthreads();
    }

    float sumf[ncols_dst] = {0.0f};
    float sumf_gate[ncols_dst];
    if constexpr (has_fusion) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
            sumf_gate[j] = 0.0f;
        }
    }

    if constexpr (std::is_same_v<T, float>) {
        const float2 * x2 = (const float2 *) x;
        const float2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const float2 *) gate_x;
            }
        }

        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const float2 tmpx = x2[col2];
            float2 tmpx_gate = make_float2(0.0f, 0.0f);
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }

#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
    } else if constexpr (std::is_same_v<T, half>) {
        const half2 * x2 = (const half2 *) x;
        const half2 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const half2 *) gate_x;
            }
        }

        if (std::is_same_v<type_acc, float>) {
            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const float2 tmpx = __half22float2(x2[col2]);
                float2 tmpx_gate = make_float2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = __half22float2(gate_x2[col2]);
                    }
                }

#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                    ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
                    if constexpr (has_fusion) {
                        if (use_gate) {
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                            ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                        }
                    }
                }
            }
        } else {
#ifdef FP16_AVAILABLE
            half2 sumh2[ncols_dst]      = {{0.0f, 0.0f}};
            half2 sumh2_gate[ncols_dst] = {{0.0f, 0.0f}};

            for (int col2 = tid; col2 < ncols2; col2 += block_size) {
                const half2 tmpx = x2[col2];
                half2 tmpx_gate = make_half2(0.0f, 0.0f);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmpx_gate = gate_x2[col2];
                    }
                }

#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    const float2 tmpy = y2[j*stride_col_y2 + col2];
                    sumh2[j] += tmpx * make_half2(tmpy.x, tmpy.y);
                    if constexpr (has_fusion) {
                        if (use_gate) {
                            sumh2_gate[j] += tmpx_gate * make_half2(tmpy.x, tmpy.y);
                        }
                    }
                }
            }
@@ -83,6 +190,15 @@ static __global__ void mul_mat_vec_f(
            for (int j = 0; j < ncols_dst; ++j) {
                sumf[j] = __low2float(sumh2[j]) + __high2float(sumh2[j]);
            }
            if constexpr (has_fusion) {
                if (use_gate) {
#pragma unroll
                    for (int j = 0; j < ncols_dst; ++j) {
                        sumf_gate[j] = __low2float(sumh2_gate[j]) + __high2float(sumh2_gate[j]);
                    }
                }
            }
#else
            NO_DEVICE_CODE;
#endif // FP16_AVAILABLE
@@ -91,8 +207,20 @@ static __global__ void mul_mat_vec_f(
        //TODO: add support for ggml_cuda_mad for hip_bfloat162
#if defined(GGML_USE_HIP)
        const int * x2 = (const int *) x;
        const int * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const int *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const int tmpx = x2[col2];
            int tmpx_gate = 0;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
@@ -100,17 +228,45 @@ static __global__ void mul_mat_vec_f(
                const float tmpx1 = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx)[1]);
                ggml_cuda_mad(sumf[j], tmpx0, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx1, tmpy.y);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        const float tmpx0_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[0]);
                        const float tmpx1_gate = ggml_cuda_cast<float>(reinterpret_cast<const nv_bfloat16 *>(&tmpx_gate)[1]);
                        ggml_cuda_mad(sumf_gate[j], tmpx0_gate, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx1_gate, tmpy.y);
                    }
                }
            }
        }
#else
        const nv_bfloat162 * x2 = (const nv_bfloat162 *) x;
        const nv_bfloat162 * gate_x2 = nullptr;
        if constexpr (has_fusion) {
            if (use_gate) {
                gate_x2 = (const nv_bfloat162 *) gate_x;
            }
        }
        for (int col2 = tid; col2 < ncols2; col2 += block_size) {
            const nv_bfloat162 tmpx = x2[col2];
            nv_bfloat162 tmpx_gate;
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmpx_gate = gate_x2[col2];
                }
            }
#pragma unroll
            for (int j = 0; j < ncols_dst; ++j) {
                const float2 tmpy = y2[j*stride_col_y2 + col2];
                ggml_cuda_mad(sumf[j], tmpx.x, tmpy.x);
                ggml_cuda_mad(sumf[j], tmpx.y, tmpy.y);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.x, tmpy.x);
                        ggml_cuda_mad(sumf_gate[j], tmpx_gate.y, tmpy.y);
                    }
                }
            }
        }
#endif
@@ -122,13 +278,31 @@ static __global__ void mul_mat_vec_f(
    for (int j = 0; j < ncols_dst; ++j) {
        sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
        if constexpr (has_fusion) {
            if (use_gate) {
                sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
            }
        }

        if (block_size > warp_size) {
            buf_iw[tid/warp_size] = sumf[j];
            if constexpr (has_fusion) {
                if (use_gate) {
                    buf_iw_gate[tid/warp_size] = sumf_gate[j];
                }
            }
            __syncthreads();
            if (tid < warp_size) {
                sumf[j] = buf_iw[tid];
                sumf[j] = warp_reduce_sum<warp_size>(sumf[j]);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        sumf_gate[j] = buf_iw_gate[tid];
                        sumf_gate[j] = warp_reduce_sum<warp_size>(sumf_gate[j]);
                    }
                }
            }
            if (j < ncols_dst) {
                __syncthreads();
            }
@@ -139,12 +313,74 @@ static __global__ void mul_mat_vec_f(
        return;
    }

    float value = sumf[tid];

    if constexpr (has_fusion) {
        if (use_bias) {
            value += x_bias[tid*stride_col_dst + row];
        }
        if (use_gate) {
            float gate_value = sumf_gate[tid];
            if (use_gate_bias) {
                gate_value += gate_bias[tid*stride_col_dst + row];
            }
            switch (glu_op) {
                case GGML_GLU_OP_SWIGLU:
                    value *= ggml_cuda_op_silu_single(gate_value);
                    break;
                case GGML_GLU_OP_GEGLU:
                    value *= ggml_cuda_op_gelu_single(gate_value);
                    break;
                case GGML_GLU_OP_SWIGLU_OAI: {
                    value = ggml_cuda_op_swiglu_oai_single(gate_value, value);
                    break;
                }
                default:
                    break;
            }
        }
    }

    dst[tid*stride_col_dst + row] = value;

    if constexpr (!has_fusion) {
        GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, glu_op, gate_x, x_bias, gate_bias, sumf_gate);
    }
}
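
The epilogue above leans on single-element activation helpers pulled in via unary.cuh; their definitions are not part of this diff. Mathematically, the fused paths compute dst = x * silu(gate) for SwiGLU and dst = x * gelu(gate) for GeGLU. A sketch of what the two simpler helpers plausibly look like (assumed, not taken from this diff; the SWIGLU_OAI variant additionally applies clamping and scaling and is omitted here):

// Assumed implementations, matching the standard SiLU/GELU formulas.
static __device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
    return x / (1.0f + expf(-x)); // SiLU, i.e. x * sigmoid(x)
}

static __device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
    // tanh approximation of GELU, as used throughout ggml
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}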
template<typename T, typename type_acc, int ncols_dst, int block_size>
static void mul_mat_vec_f_switch_fusion(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
        const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
        const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) {
    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    if constexpr (ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>>
        (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
template <typename T, typename type_acc, int ncols_dst>
void launch_mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -176,57 +412,59 @@ static void launch_mul_mat_vec_f_cuda(
        }
    }

    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
    const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
    const dim3 block_nums(nrows, nchannels_dst, nsamples_dst);
    const dim3 block_dims(block_size_best, 1, 1);
    switch (block_size_best) {
        case 32: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 64: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 96: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 128: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 160: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 192: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 224: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        case 256: {
            mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256>
                (x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst,
                 channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream);
        } break;
        default: {
            GGML_ABORT("fatal error");
@@ -236,7 +474,7 @@ static void launch_mul_mat_vec_f_cuda(
template <typename T, typename type_acc>
static void mul_mat_vec_f_cuda_switch_ncols_dst(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
@@ -246,49 +484,49 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
    switch (ncols_dst) {
        case 1:
            launch_mul_mat_vec_f_cuda<T, type_acc, 1>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 2:
            launch_mul_mat_vec_f_cuda<T, type_acc, 2>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 3:
            launch_mul_mat_vec_f_cuda<T, type_acc, 3>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 4:
            launch_mul_mat_vec_f_cuda<T, type_acc, 4>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 5:
            launch_mul_mat_vec_f_cuda<T, type_acc, 5>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 6:
            launch_mul_mat_vec_f_cuda<T, type_acc, 6>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 7:
            launch_mul_mat_vec_f_cuda<T, type_acc, 7>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case 8:
            launch_mul_mat_vec_f_cuda<T, type_acc, 8>
                (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
@@ -300,29 +538,31 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
template<typename T>
static void mul_mat_vec_f_cuda(
        const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int64_t ncols, const int64_t nrows, const int64_t ncols_dst,
        const int64_t stride_row, const int64_t stride_col_y, const int stride_col_dst,
        const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
        const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
        const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
        enum ggml_prec prec, cudaStream_t stream) {
    if constexpr(std::is_same_v<T, half>) {
        if (prec == GGML_PREC_DEFAULT) {
            mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
                (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
                 stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            return;
        }
    }
    mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
        (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
         nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
         stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
}

void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
                             const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids ||  ids->type == GGML_TYPE_I32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
@@ -348,6 +588,30 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
    float * dst_d = (float *) dst->data;

    ggml_cuda_mm_fusion_args_device fusion_local{};
    if (fusion) {
        GGML_ASSERT( !ids || dst->ne[2] == 1);
        GGML_ASSERT(  ids || dst->ne[1] == 1);

        if (fusion->x_bias) {
            GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
            fusion_local.x_bias = fusion->x_bias->data;
        }
        if (fusion->gate) {
            GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
            fusion_local.gate = fusion->gate->data;
        }
        if (fusion->gate_bias) {
            GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
            GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
            GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
            fusion_local.gate_bias = fusion->gate_bias->data;
        }
        fusion_local.glu_op = fusion->glu_op;
    }

    const int64_t s01 = src0->nb[1] / ts_src0;
    const int64_t s11 = src1->nb[1] / ts_src1;
    const int64_t s1  = dst->nb[1]  / ts_dst;
@@ -370,19 +634,19 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
            mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1,
                               ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
                               ne03, ne3, s03, s13, s3, prec, ctx.stream());
        } break;
@@ -409,7 +673,6 @@ void ggml_cuda_op_mul_mat_vec_f(
    const int cc = ggml_cuda_info().devices[id].cc;
    const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
    // ggml_cuda_op provides single, contiguous matrices
    const int64_t stride_row   = ne00;
    const int64_t stride_col_y = ne10;
@@ -426,22 +689,23 @@ void ggml_cuda_op_mul_mat_vec_f(
    const int64_t stride_sample_y   = 0;
    const int64_t stride_sample_dst = 0;

    ggml_cuda_mm_fusion_args_device empty{};

    switch (src0->type) {
        case GGML_TYPE_F32: {
            const float * src0_d = (const float *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_F16: {
            const half * src0_d = (const half *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
        case GGML_TYPE_BF16: {
            const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
            mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
                               nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                               nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
        } break;
@@ -452,10 +716,23 @@ void ggml_cuda_op_mul_mat_vec_f(
    GGML_UNUSED_VARS(ctx, src1, dst, src1_ddq_i, src1_ncols, src1_padded_row_size);
}

bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11) {
    if (src0_ne[0] % 2 != 0) {
        return false;
    }

    const size_t ts = ggml_type_size(type);
    if (src0_nb[0] != ts) {
        return false;
    }
    // Pointers not aligned to the size of half2/nv_bfloat162/float2 would result in a crash:
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (src0_nb[i] % (2*ts) != 0) {
            return false;
        }
    }

    switch (type) {
        case GGML_TYPE_F32:
            if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
...
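
The new stride checks in ggml_cuda_should_use_mmvf above exist because the kernel reads src0 two elements at a time through float2/half2/nv_bfloat162 pointers; any higher-dimension byte stride that is not a multiple of 2*type_size would hand those vector loads a misaligned address. A standalone restatement of the same predicate, with a worked example (illustrative only, not part of the diff):

// Same logic as the loop above. For GGML_TYPE_F16, ggml_type_size() == 2, so
// nb[1..3] must be multiples of 4 bytes:
//   nb[1] == 6 -> 6 % 4 != 0 -> reject (a half2 load from row 1 would be misaligned)
//   nb[1] == 8 -> 8 % 4 == 0 -> accept
static bool mmvf_strides_aligned(size_t ts /*ggml_type_size*/, const size_t * nb) {
    for (size_t i = 1; i < GGML_MAX_DIMS; ++i) {
        if (nb[i] % (2*ts) != 0) {
            return false;
        }
    }
    return true;
}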
#include "common.cuh" #include "common.cuh"
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
void ggml_cuda_op_mul_mat_vec_f( void ggml_cuda_op_mul_mat_vec_f(
ggml_backend_cuda_context & ctx, ggml_backend_cuda_context & ctx,
@@ -8,4 +9,4 @@ void ggml_cuda_op_mul_mat_vec_f(
    const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
    const int64_t src1_padded_row_size, cudaStream_t stream);

bool ggml_cuda_should_use_mmvf(enum ggml_type type, int cc, const int64_t * src0_ne, const size_t * src0_nb, int64_t ne11);
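
Because the new fusion parameter defaults to nullptr, every existing caller of ggml_cuda_mul_mat_vec_f keeps compiling unchanged; only a fused dispatch passes a populated struct. A hypothetical call site (the tensor names are illustrative, not from this diff), assuming a graph-level pass has already matched a mat-vec followed by a SwiGLU:

// Hypothetical caller: fuse dst = swiglu(gate_w @ src1, up_w @ src1).
ggml_cuda_mm_fusion_args_host fusion{};
fusion.gate   = gate_w;               // must match src0's type and strides
fusion.glu_op = GGML_GLU_OP_SWIGLU;
ggml_cuda_mul_mat_vec_f(ctx, up_w, src1, /*ids=*/nullptr, dst, &fusion);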
#include "mmvq.cuh" #include "mmvq.cuh"
#include "quantize.cuh" #include "quantize.cuh"
#include "unary.cuh"
#include "vecdotq.cuh" #include "vecdotq.cuh"
#include <cstdint> #include <cstdint>
@@ -82,7 +83,7 @@ static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
    return MMVQ_PARAMETERS_GENERIC;
}

static constexpr __host__ __device__ int calc_nwarps(int ncols_dst, mmvq_parameter_table_id table_id) {
    if (table_id == MMVQ_PARAMETERS_GENERIC) {
        switch (ncols_dst) {
            case 1:
@@ -136,11 +137,11 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
    return 1;
}

// tell the compiler to use as many registers as it wants, see nwarps definition below
template <ggml_type type, int ncols_dst, bool has_fusion>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
        const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
@@ -169,8 +170,56 @@ static __global__ void mul_mat_vec_q(
    const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
    const uint32_t sample_y = sample_dst;

    bool use_gate      = false;
    bool use_bias      = false;
    bool use_gate_bias = false;
    const void  * vgate     = nullptr;
    const float * x_bias    = nullptr;
    const float * gate_bias = nullptr;
    ggml_glu_op active_glu;

    if constexpr (has_fusion) {
        use_gate      = fusion.gate != nullptr;
        use_bias      = fusion.x_bias != nullptr;
        use_gate_bias = fusion.gate_bias != nullptr && use_gate;
        vgate      = fusion.gate;
        x_bias     = (const float *) fusion.x_bias;
        gate_bias  = (const float *) fusion.gate_bias;
        active_glu = fusion.glu_op;
    }

    const uint32_t channel_bias = ids ? channel_x : channel_dst;

    float x_biases[ncols_dst]    = { 0.0f };
    float gate_biases[ncols_dst] = { 0.0f };
    if constexpr (has_fusion) {
        if (use_bias) {
            x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            // 1. Hide latency by prefetching bias and gate here
            // 2. load only on threads that won't die after partial sum calculation
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    x_biases[j] = x_bias[j * stride_col_dst + threadIdx.x];
                }
            }
        }
        if (use_gate_bias) {
            gate_bias = gate_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
            if (threadIdx.x < rows_per_cuda_block && threadIdx.y == 0 &&
                (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
#pragma unroll
                for (int j = 0; j < ncols_dst; ++j) {
                    gate_biases[j] = gate_bias[j * stride_col_dst + threadIdx.x];
                }
            }
        }
    }

    // partial sum for each thread
    float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
    float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};

    const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
    const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
@@ -187,17 +236,35 @@ static __global__ void mul_mat_vec_q(
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp[j][i] += vec_dot_q_cuda(
                    vx, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_gate[j][i] += vec_dot_q_cuda(
                            vgate, &y[j*stride_col_y + kby], kbx_offset + i*stride_row_x + kbx, kqs);
                    }
                }
            }
        }
    }

    __shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    __shared__ float tmp_shared_gate[(has_fusion && (nwarps-1 > 0)) ? nwarps-1 : 1][ncols_dst][rows_per_cuda_block][warp_size];
    if constexpr (!has_fusion) {
        (void) tmp_shared_gate;
    } else if (!use_gate) {
        (void) tmp_shared_gate;
    }

    if (threadIdx.y > 0) {
#pragma unroll
        for (int j = 0; j < ncols_dst; ++j) {
#pragma unroll
            for (int i = 0; i < rows_per_cuda_block; ++i) {
                tmp_shared[threadIdx.y-1][j][i][threadIdx.x] = tmp[j][i];
                if constexpr (has_fusion) {
                    if (use_gate) {
                        tmp_shared_gate[threadIdx.y-1][j][i][threadIdx.x] = tmp_gate[j][i];
                    }
                }
            }
        }
    }
@@ -216,14 +283,55 @@ static __global__ void mul_mat_vec_q(
#pragma unroll
        for (int l = 0; l < nwarps-1; ++l) {
            tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
            if constexpr (has_fusion) {
                if (use_gate) {
                    tmp_gate[j][i] += tmp_shared_gate[l][j][i][threadIdx.x];
                }
            }
        }
        tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
        if constexpr (has_fusion) {
            if (use_gate) {
                tmp_gate[j][i] = warp_reduce_sum<warp_size>(tmp_gate[j][i]);
            }
        }
    }

    if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || uint32_t(row0 + threadIdx.x) < stride_col_dst)) {
        float result = tmp[j][threadIdx.x];
        if constexpr (has_fusion) {
            if (use_bias) {
                result += x_biases[j];
            }
            if (use_gate) {
                float gate_value = tmp_gate[j][threadIdx.x];
                if (use_gate_bias) {
                    gate_value += gate_biases[j];
                }
                switch (active_glu) {
                    case GGML_GLU_OP_SWIGLU:
                        result *= ggml_cuda_op_silu_single(gate_value);
                        break;
                    case GGML_GLU_OP_GEGLU:
                        result *= ggml_cuda_op_gelu_single(gate_value);
                        break;
                    case GGML_GLU_OP_SWIGLU_OAI: {
                        result = ggml_cuda_op_swiglu_oai_single(gate_value, result);
                        break;
                    }
                    default:
                        result = result * gate_value;
                        break;
                }
            }
        }
        dst[j*stride_col_dst + threadIdx.x] = result;
    }
}

if constexpr (!has_fusion) {
    GGML_UNUSED_VARS(use_gate, use_bias, use_gate_bias, active_glu, gate_bias, x_bias, tmp_gate);
}
}
@@ -235,9 +343,37 @@ static std::pair<dim3, dim3> calc_launch_params(
    return {block_nums, block_dims};
}

template<ggml_type type, int c_ncols_dst>
static void mul_mat_vec_q_switch_fusion(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
        const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
        const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
        const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
        const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) {
    const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;

    if constexpr (c_ncols_dst == 1) {
        if (has_fusion) {
            mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>>
                (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
                 channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
                 sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
            return;
        }
    }

    GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");

    mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>>
        (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
         channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
         sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}

template <ggml_type type>
static void mul_mat_vec_q_switch_ncols_dst(
        const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
        const int ncols_x, const int nrows_x, const int ncols_dst,
        const int stride_row_x, const int stride_col_y, const int stride_col_dst,
        const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -256,80 +392,83 @@ static void mul_mat_vec_q_switch_ncols_dst(
const int warp_size = ggml_cuda_info().devices[device].warp_size; const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
GGML_ASSERT(!ids || ncols_dst == 1); GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) { switch (ncols_dst) {
case 1: { case 1: {
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream);
} break; } break;
case 2: { case 2: {
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 3: {
            constexpr int c_ncols_dst = 3;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 4: {
            constexpr int c_ncols_dst = 4;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 5: {
            constexpr int c_ncols_dst = 5;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 6: {
            constexpr int c_ncols_dst = 6;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 7: {
            constexpr int c_ncols_dst = 7;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        case 8: {
            constexpr int c_ncols_dst = 8;
            std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
            mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
                channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
                sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
                dims.first, dims.second, 0, stream);
        } break;
        default:
            GGML_ABORT("fatal error");
            break;
    }

    GGML_UNUSED(has_fusion);
}
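For orientation, the epilogue that the fused dispatch above enables amounts to the following per-output-element computation. This is a minimal host-side sketch, assuming the SWIGLU variant of ggml_glu_op; which projection receives the activation is the kernel's convention, so the operand order here is illustrative only:

#include <cmath>

// Scalar model of the fused mul-mat-vec epilogue: the main and gate
// projections share one kernel launch, and the GLU activation is applied
// in-place instead of running separate bias-add and GLU kernels.
static float fused_swiglu_epilogue(float x_dot,       // row of x dotted with y
                                   float gate_dot,    // row of gate dotted with y
                                   float x_bias,      // 0.0f when fusion.x_bias is absent
                                   float gate_bias) { // 0.0f when fusion.gate_bias is absent
    const float a = x_dot    + x_bias;
    const float g = gate_dot + gate_bias;
    const float silu_a = a / (1.0f + expf(-a)); // SiLU
    return silu_a * g;                          // SWIGLU: SiLU(a) * g
}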
static void mul_mat_vec_q_switch_type(
    const void * vx, const ggml_type type_x, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
    const int ncols_x, const int nrows_x, const int ncols_dst,
    const int stride_row_x, const int stride_col_y, const int stride_col_dst,
    const int nchannels_x, const int nchannels_y, const int nchannels_dst,
@@ -339,143 +478,123 @@ static void mul_mat_vec_q_switch_type(
    switch (type_x) {
        case GGML_TYPE_Q4_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q4_1:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q5_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q5_1:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q8_0:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_MXFP4:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q2_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q3_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q4_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q5_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_Q6_K:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ2_XXS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ2_XS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ2_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ3_XXS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ1_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ1_M:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ4_NL:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ4_XS:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        case GGML_TYPE_IQ3_S:
            mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
                (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
                 nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
                 nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
            break;
        default:
            GGML_ABORT("fatal error");
@@ -484,7 +603,8 @@ static void mul_mat_vec_q_switch_type(
}
void ggml_cuda_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
    const ggml_cuda_mm_fusion_args_host * fusion) {
    GGML_ASSERT( src1->type == GGML_TYPE_F32);
    GGML_ASSERT(  dst->type == GGML_TYPE_F32);
    GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); // Optional, used for batched GGML_MUL_MAT_ID.
@@ -508,6 +628,31 @@ void ggml_cuda_mul_mat_vec_q(
    const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
    float * dst_d = (float *) dst->data;
ggml_cuda_mm_fusion_args_device fusion_local{};
if (fusion) {
GGML_ASSERT( !ids || dst->ne[2] == 1);
GGML_ASSERT( ids || dst->ne[1] == 1);
if (fusion->x_bias) {
GGML_ASSERT(fusion->x_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->x_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->x_bias->ne[1] == src0->ne[2]);
fusion_local.x_bias = fusion->x_bias->data;
}
if (fusion->gate) {
GGML_ASSERT(fusion->gate->type == src0->type && ggml_are_same_stride(fusion->gate, src0));
fusion_local.gate = fusion->gate->data;
}
if (fusion->gate_bias) {
GGML_ASSERT(fusion->gate_bias->type == GGML_TYPE_F32);
GGML_ASSERT(fusion->gate_bias->ne[0] == dst->ne[0]);
GGML_ASSERT(!ids || fusion->gate_bias->ne[1] == src0->ne[2]);
fusion_local.gate_bias = fusion->gate_bias->data;
}
fusion_local.glu_op = fusion->glu_op;
}
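A minimal sketch of how a call site might populate the host-side descriptor validated above (the ffn_* tensor names are hypothetical; only the struct fields and the entry-point signature are taken from this change):

// Hypothetical call site: fuse the up projection, gate projection, biases,
// and SWIGLU into a single mul-mat-vec launch. Without ids, dst->ne[1] must
// be 1 (single-token decode), as asserted above.
ggml_cuda_mm_fusion_args_host fusion{};
fusion.gate      = ffn_gate;      // same type and strides as src0
fusion.x_bias    = ffn_up_bias;   // F32, ne[0] == dst->ne[0], may be nullptr
fusion.gate_bias = ffn_gate_bias; // F32, only honored when gate is set
fusion.glu_op    = GGML_GLU_OP_SWIGLU;
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, /*ids=*/nullptr, dst, &fusion);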
    // If src0 is a temporary compute buffer, clear any potential padding.
    if (ggml_backend_buffer_get_usage(src0->buffer) == GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
        const size_t size_data = ggml_nbytes(src0);
@@ -549,10 +694,10 @@ void ggml_cuda_mul_mat_vec_q(
    const int64_t stride_channel_y = ids ? s11 : s12;

    mul_mat_vec_q_switch_type(
        src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
        ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
        ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
        ne03, ne3, s03, s13, s3, stream);
}
void ggml_cuda_op_mul_mat_vec_q(
@@ -578,8 +723,9 @@ void ggml_cuda_op_mul_mat_vec_q(
    const int stride_row_x = ne00 / ggml_blck_size(src0->type);
    const int stride_col_y = src1_padded_row_size / QK8_1;

    ggml_cuda_mm_fusion_args_device fusion_local{};
    mul_mat_vec_q_switch_type(
        src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream);

    GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
...
@@ -3,7 +3,7 @@

#define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.

void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr);

void ggml_cuda_op_mul_mat_vec_q(
    ggml_backend_cuda_context & ctx,
...
#include "convert.cuh"
#include "ggml-cuda/common.cuh"
#include "ggml.h"
#include "rope.cuh" #include "rope.cuh"
struct rope_corr_dims { struct rope_corr_dims {
...@@ -37,11 +40,23 @@ static __device__ void rope_yarn( ...@@ -37,11 +40,23 @@ static __device__ void rope_yarn(
} }
} }
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
                                 D * dst,
                                 const int ne0,
                                 const int ne1,
                                 const int s1,
                                 const int s2,
                                 const int n_dims,
                                 const int32_t * pos,
                                 const float freq_scale,
                                 const float ext_factor,
                                 const float attn_factor,
                                 const rope_corr_dims corr_dims,
                                 const float theta_scale,
                                 const float * freq_factors,
                                 const int64_t * row_indices,
                                 const int set_rows_stride) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
@@ -53,13 +68,27 @@ static __global__ void rope_norm(
    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    int idst = row_dst * ne0 + i0;
    const int ix = channel_x*s2 + row_x*s1 + i0;

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
    if (set_rows_stride != 0) {
        idst  = row_x * ne0 + i0;
        idst += row_indices[channel_x] * set_rows_stride;
    }

    const auto & store_coalesced = [&](float x0, float x1) {
        if constexpr (std::is_same_v<float, D>) {
            float2 v = make_float2(x0, x1);
            ggml_cuda_memcpy_1<8>(dst + idst, &v);
        } else if constexpr (std::is_same_v<half, D>) {
            half2 v = make_half2(x0, x1);
            ggml_cuda_memcpy_1<4>(dst + idst, &v);
        }
    };

    if (i0 >= n_dims) {
        store_coalesced(x[ix + 0], x[ix + 1]);
        return;
    }
@@ -75,15 +104,26 @@ static __global__ void rope_norm(
    const float x0 = x[ix + 0];
    const float x1 = x[ix + 1];

    store_coalesced(x0 * cos_theta - x1 * sin_theta, x0 * sin_theta + x1 * cos_theta);
}
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
                                 D * dst,
                                 const int ne0,
                                 const int ne1,
                                 const int s1,
                                 const int s2,
                                 const int n_dims,
                                 const int32_t * pos,
                                 const float freq_scale,
                                 const float ext_factor,
                                 const float attn_factor,
                                 const rope_corr_dims corr_dims,
                                 const float theta_scale,
                                 const float * freq_factors,
                                 const int64_t * row_indices,
                                 const int set_rows_stride) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
@@ -95,12 +135,19 @@ static __global__ void rope_neox(
    const int row_x     = row_dst % ne1;
    const int channel_x = row_dst / ne1;

    int idst = row_dst * ne0 + i0 / 2;
    const int ix = channel_x*s2 + row_x*s1 + i0/2;

    // Fusion optimization: ROPE + VIEW + SET_ROWS.
    // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
    if (set_rows_stride != 0) {
        idst  = row_x * ne0 + i0 / 2;
        idst += row_indices[channel_x] * set_rows_stride;
    }

    if (i0 >= n_dims) {
        dst[idst + i0 / 2 + 0] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 0]);
        dst[idst + i0 / 2 + 1] = ggml_cuda_cast<D>(x[ix + i0 / 2 + 1]);
        return;
    }
@@ -117,15 +164,15 @@ static __global__ void rope_neox(
    const float x0 = x[ix + 0];
    const float x1 = x[ix + n_dims/2];

    dst[idst + 0]          = ggml_cuda_cast<D>(x0 * cos_theta - x1 * sin_theta);
    dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
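The index arithmetic above is the entire ROPE + VIEW + SET_ROWS fusion: each thread computes the final destination slot directly instead of writing a contiguous temporary that a separate SET_ROWS kernel would then scatter. A host-side model of the fused index (a sketch; the function name is illustrative):

#include <cstdint>

// Destination index for element (channel, row, i0) when the fusion is active:
// row_indices maps each source channel to a destination row, and
// set_rows_stride is the destination row stride in elements.
static int64_t fused_dst_index(const int64_t * row_indices, int set_rows_stride,
                               int channel, int row, int ne0, int i0) {
    return row_indices[channel] * (int64_t) set_rows_stride
         + (int64_t) row * ne0 + i0;
}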
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
        const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
        const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
        const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
    const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);

    if (i0 >= ne0) {
@@ -151,12 +198,30 @@ static __global__ void rope_multi(
    const int sec_w = sections.v[1] + sections.v[0];
    const int sector = (i0 / 2) % sect_dims;

    float theta_base = 0.0;
    if (is_imrope) {
        if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { // h
            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
        } else if (sector % 3 == 2 && sector < 2 + 3 * sections.v[2]) { // w
            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
        } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
        // } else {
        //     theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
        }
    } else {
        if (sector < sections.v[0]) {
            theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
        }
        else if (sector >= sections.v[0] && sector < sec_w) {
            theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
        }
        else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
            theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
        }
        else if (sector >= sec_w + sections.v[2]) {
            theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
        }
    }
    const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@@ -220,11 +285,25 @@ static __global__ void rope_vision(
    dst[idst + n_dims] = x0*sin_theta + x1*cos_theta;
}
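For clarity, the interleaved-MRoPE branch above cycles the t/h/w position streams every three sectors, while classic MRoPE assigns contiguous sector blocks per section. The selection logic in isolation (a sketch mirroring the kernel, including its commented-out fallback):

// Which position stream an imrope sector reads: sectors cycle t,h,w subject
// to per-section limits; -1 models the kernel's disabled else branch, where
// theta_base simply stays 0.
static int imrope_stream(int sector, const int sections[3]) {
    if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) return 1; // h
    if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) return 2; // w
    if (sector % 3 == 0 && sector < 3 * sections[0])     return 0; // t
    return -1;
}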
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
                           D * dst,
                           const int ne0,
                           const int ne1,
                           const int s1,
                           const int s2,
                           const int n_dims,
                           const int nr,
                           const int32_t * pos,
                           const float freq_scale,
                           const float freq_base,
                           const float ext_factor,
                           const float attn_factor,
                           const rope_corr_dims corr_dims,
                           const float * freq_factors,
                           const int64_t * row_indices,
                           const int set_rows_stride,
                           cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -234,20 +313,34 @@ static void rope_norm_cuda(
    if (freq_factors == nullptr) {
        rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
            freq_factors, row_indices, set_rows_stride);
    } else {
        rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
            freq_factors, row_indices, set_rows_stride);
    }
}
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
                           D * dst,
                           const int ne0,
                           const int ne1,
                           const int s1,
                           const int s2,
                           const int n_dims,
                           const int nr,
                           const int32_t * pos,
                           const float freq_scale,
                           const float freq_base,
                           const float ext_factor,
                           const float attn_factor,
                           const rope_corr_dims corr_dims,
                           const float * freq_factors,
                           const int64_t * row_indices,
                           const int set_rows_stride,
                           cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -256,13 +349,13 @@ static void rope_neox_cuda(
    const float theta_scale = powf(freq_base, -2.0f/n_dims);

    if (freq_factors == nullptr) {
        rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
            freq_factors, row_indices, set_rows_stride);
    } else {
        rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
            freq_factors, row_indices, set_rows_stride);
    }
}
@@ -270,7 +363,7 @@ template<bool forward, typename T>
static void rope_multi_cuda(
    const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
    const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
    const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
    GGML_ASSERT(ne0 % 2 == 0);
    const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
    const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
@@ -281,11 +374,11 @@ static void rope_multi_cuda(
    if (freq_factors == nullptr) {
        rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    } else {
        rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
            x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
            attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
    }
}
@@ -315,7 +408,9 @@ static void rope_vision_cuda(
}

template <bool forward>
void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
                            ggml_tensor * dst,
                            const ggml_tensor * set_rows = nullptr) {
    const ggml_tensor * src0 = dst->src[0];
    const ggml_tensor * src1 = dst->src[1];
    const ggml_tensor * src2 = dst->src[2];
@@ -323,12 +418,25 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    const float * src0_d = (const float *)src0->data;
    const float * src1_d = (const float *)src1->data;

    void * dst_d = dst->data;
const int64_t * row_indices = nullptr;
ggml_type dst_type = dst->type;
int set_rows_stride = 0;
if (set_rows != nullptr) {
GGML_ASSERT(forward);
dst_d = set_rows->data;
row_indices = (const int64_t *) set_rows->src[1]->data;
dst_type = set_rows->type;
set_rows_stride = set_rows->nb[1] / ggml_type_size(set_rows->type);
}
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16);
    GGML_ASSERT( dst->type == GGML_TYPE_F32 ||  dst->type == GGML_TYPE_F16);
    // When not fused, src0 and dst types must match.
    // When fused (ROPE + VIEW + SET_ROWS), src0 may be F32 while dst is F16.
    GGML_ASSERT(src0->type == dst->type || (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16));
    const int64_t ne00 = src0->ne[0]; // head dims
    const int64_t ne01 = src0->ne[1]; // num heads
@@ -363,6 +471,7 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    const bool is_neox   = mode & GGML_ROPE_TYPE_NEOX;
    const bool is_mrope  = mode & GGML_ROPE_TYPE_MROPE;
    const bool is_imrope = mode == GGML_ROPE_TYPE_IMROPE;
    const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
    if (is_mrope) {
@@ -385,14 +494,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
    // compute
    if (is_neox) {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                freq_factors, row_indices, set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
@@ -400,11 +513,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
        if (src0->type == GGML_TYPE_F32) {
            rope_multi_cuda<forward>(
                (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else if (src0->type == GGML_TYPE_F16) {
            rope_multi_cuda<forward>(
                (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
                freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
        } else {
            GGML_ABORT("fatal error");
        }
@@ -421,14 +534,18 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
            GGML_ABORT("fatal error");
        }
    } else {
        if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
            rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                  nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                  freq_factors, row_indices, set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
                                                 nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                 freq_factors, row_indices, set_rows_stride, stream);
        } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
            rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
                                                pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
                                                freq_factors, row_indices, set_rows_stride, stream);
        } else {
            GGML_ABORT("fatal error");
        }
@@ -442,3 +559,7 @@ void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_rope_impl<false>(ctx, dst);
}
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * rope, ggml_tensor * set_rows) {
ggml_cuda_op_rope_impl<true>(ctx, rope, set_rows);
}
@@ -5,3 +5,5 @@

void ggml_cuda_op_rope(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rope_fused(ggml_backend_cuda_context & ctx, ggml_tensor * dst, ggml_tensor * set_rows);
@@ -4,30 +4,53 @@

typedef void (*set_rows_kernel_t)(const char * src, char * dst);

// Generic quantized set_rows kernel template
template <typename idx_t, typename block_type, int qk, void (*quantize_func)(const float *, block_type *)>
static __global__ void k_set_rows_quant(const float * __restrict__ src0,
                                        const idx_t * __restrict__ src1,
                                        block_type * __restrict__ dst,
                                        const int64_t ne_total,
                                        const int64_t ne10,
                                        const int64_t ne11,
                                        const int64_t ne12,
                                        const int64_t ne13,
                                        const int64_t s01,
                                        const int64_t s02,
                                        const int64_t s03,
                                        const int64_t s10,
                                        const int64_t s11,
                                        const int64_t s12,
                                        const int64_t s1,
                                        const int64_t s2,
                                        const int64_t s3,
                                        const uint3 ne00,
                                        const uint3 ne01,
                                        const uint3 ne02,
                                        const uint3 ne11_fd,
                                        const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

    const int64_t i_base = i * qk;

    uint32_t tmp = (uint32_t) i_base;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -41,6 +64,8 @@ static __global__ void k_set_rows_quant(
    quantize_func(src_block, dst_block);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}
@@ -71,40 +96,65 @@ static void set_rows_cuda_quant(
    const int64_t s2 = nb2;
    const int64_t s3 = nb3;

    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows_quant<idx_t, block_type, qk, quantize_func><<<grid_size, block_size, 0, stream>>>(
            src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01, s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd,
            ne01_fd, ne02_fd, ne11_fd, ne12_fd);
    }
}
template <typename src_t, typename idx_t, typename dst_t>
static __global__ void k_set_rows(const src_t * __restrict__ src0,
                                  const idx_t * __restrict__ src1,
                                  dst_t * __restrict__ dst,
                                  const int64_t ne_total,
                                  const int64_t ne10,
                                  const int64_t ne11,
                                  const int64_t ne12,
                                  const int64_t ne13,
                                  const int64_t s01,
                                  const int64_t s02,
                                  const int64_t s03,
                                  const int64_t s10,
                                  const int64_t s11,
                                  const int64_t s12,
                                  const int64_t s1,
                                  const int64_t s2,
                                  const int64_t s3,
                                  const uint3 ne00,
                                  const uint3 ne01,
                                  const uint3 ne02,
                                  const uint3 ne11_fd,
                                  const uint3 ne12_fd) {
    const int64_t i = int64_t(blockDim.x) * blockIdx.x + threadIdx.x;

    if (i >= ne_total) {
        return;
    }

    uint32_t tmp = (uint32_t) i;
    uint2 div_mod;

    div_mod = fast_div_modulo(tmp, ne00);
    const int64_t i00 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne01);
    const int64_t i01 = div_mod.y;
    tmp = div_mod.x;

    div_mod = fast_div_modulo(tmp, ne02);
    const int64_t i02 = div_mod.y;
    const int64_t i03 = div_mod.x;

    const int64_t i12 = fastmodulo((uint32_t) i03, ne12_fd);
    const int64_t i11 = fastmodulo((uint32_t) i02, ne11_fd);
    const int64_t i10 = i01;

    const int64_t dst_row = *(src1 + i10*s10 + i11*s11 + i12*s12);
@@ -115,6 +165,8 @@ static __global__ void k_set_rows(
    dst_row_ptr[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);

    GGML_UNUSED(ne10);
    GGML_UNUSED(ne11);
    GGML_UNUSED(ne12);
    GGML_UNUSED(ne13);
}
@@ -144,14 +196,16 @@ static void set_rows_cuda(
    const int64_t s2 = nb2/sizeof(dst_t);
    const int64_t s3 = nb3/sizeof(dst_t);

    if (ne_total > 0 && ne00 > 0 && ne01 > 0 && ne02 > 0 && ne11 > 0 && ne12 > 0) {
        const uint3 ne00_fd = init_fastdiv_values((uint32_t) ne00);
        const uint3 ne01_fd = init_fastdiv_values((uint32_t) ne01);
        const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
        const uint3 ne11_fd = init_fastdiv_values((uint32_t) ne11);
        const uint3 ne12_fd = init_fastdiv_values((uint32_t) ne12);

        k_set_rows<<<grid_size, block_size, 0, stream>>>(src0_d, src1_d, dst_d, ne_total, ne10, ne11, ne12, ne13, s01,
                                                         s02, s03, s10, s11, s12, s1, s2, s3, ne00_fd, ne01_fd, ne02_fd,
                                                         ne11_fd, ne12_fd);
    }
}
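The fast_div_modulo/fastmodulo calls above replace per-element integer division and modulo, which are slow on GPUs, with a precomputed multiply-high, add, and shift (Granlund-Montgomery); init_fastdiv_values packs the magic constant, shift, and divisor into a uint3 once on the host. An illustrative host-side reimplementation of the idea, not the library's exact helpers:

#include <cstdint>

struct fastdiv_vals { uint32_t magic; uint32_t shift; uint32_t d; };

// Precompute the magic multiplier for divisor d (d >= 1).
static fastdiv_vals make_fastdiv(uint32_t d) {
    uint32_t L = 0; // L = ceil(log2(d))
    while (L < 32 && (uint32_t{1} << L) < d) {
        L++;
    }
    const uint32_t magic = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return { magic, L, d };
}

// Equivalent of fast_div_modulo's {quotient, remainder} pair; the device code
// computes the high 32 bits of the product with __umulhi.
static void fast_divmod_ref(uint32_t n, fastdiv_vals f, uint32_t & q, uint32_t & r) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * f.magic) >> 32);
    q = (uint32_t) (((uint64_t) hi + n) >> f.shift);
    r = n - q * f.d;
}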
...
#include "set.cuh"
#include "cpy.cuh"
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT((src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32));
GGML_ASSERT(src1->type == src0->type);
GGML_ASSERT(dst ->type == src0->type);
GGML_ASSERT(ggml_is_contiguous(dst));
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const size_t nb1 = ((int32_t *) dst->op_params)[0];
const size_t nb2 = ((int32_t *) dst->op_params)[1];
const size_t nb3 = ((int32_t *) dst->op_params)[2];
const size_t offset = ((int32_t *) dst->op_params)[3];
    const bool inplace = (bool) ((int32_t *) dst->op_params)[4];
if (!inplace) {
ggml_cuda_cpy(ctx, src0, dst);
}
ggml_tensor dst_view = *dst;
dst_view.data = (void *)((char *)dst->data + offset);
dst_view.ne[0] = src1->ne[0];
dst_view.ne[1] = src1->ne[1];
dst_view.ne[2] = src1->ne[2];
dst_view.ne[3] = src1->ne[3];
dst_view.nb[0] = ggml_element_size(dst);
dst_view.nb[1] = nb1;
dst_view.nb[2] = nb2;
dst_view.nb[3] = nb3;
ggml_cuda_cpy(ctx, src1, &dst_view);
}
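The view gymnastics above implement GGML_OP_SET: dst starts as a copy of src0 (skipped when in-place), then src1 overwrites the strided window described by the op_params byte strides and offset. A host-side model of the resulting layout (a sketch assuming F32 and a contiguous src1):

#include <cstddef>
#include <cstdint>

// Write src1 into the window of dst located at byte `offset` with byte
// strides nb1/nb2/nb3; dim 0 is contiguous, matching dst_view above.
static void set_f32_ref(float * dst, const float * src1,
                        int64_t ne10, int64_t ne11, int64_t ne12, int64_t ne13,
                        size_t nb1, size_t nb2, size_t nb3, size_t offset) {
    char * base = (char *) dst + offset;
    for (int64_t i3 = 0; i3 < ne13; ++i3) {
        for (int64_t i2 = 0; i2 < ne12; ++i2) {
            for (int64_t i1 = 0; i1 < ne11; ++i1) {
                float * d = (float *) (base + i1*nb1 + i2*nb2 + i3*nb3);
                for (int64_t i0 = 0; i0 < ne10; ++i0) {
                    d[i0] = src1[((i3*ne12 + i2)*ne11 + i1)*ne10 + i0];
                }
            }
        }
    }
}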
#pragma once
#include "common.cuh"
#define CUDA_SET_BLOCK_SIZE 256
void ggml_cuda_op_set(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When n_template/k_template == 0 the bounds for the loops in this function
// are not known at compile time and can't be unrolled. As we want to keep
// pragma unroll for all other cases we suppress the clang transformation
// warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
__syncthreads();
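    // forward substitution: row i depends only on rows 0..i-1; each lane covers
    // j = lane and, once row >= WARP_SIZE, also j = WARP_SIZE + lane, after
    // which the partial sums are reduced across the warp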
#pragma unroll
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
switch (k) {
case 32:
solve_tri_f32_fast<64, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
    const ggml_tensor * src1 = dst->src[1]; // B (right-hand side, n x k)

    GGML_ASSERT(ggml_is_contiguous(src0));
    GGML_ASSERT(ggml_is_contiguous(src1));
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
GGML_ASSERT(n <= 64);
GGML_ASSERT(k <= 32);
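    // the fast kernel stages A and X^T in shared memory, which caps n and k at
    // MAX_N_FAST / MAX_K_FAST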
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-tile.cuh"
DECL_FATTN_TILE_CASE(72, 72);
@@ -2,6 +2,7 @@
#include "ggml.h"
#include "topk-moe.cuh"

#include <cmath>
#include <initializer_list>

// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
@@ -63,7 +64,8 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
                                                                  float * weights,
                                                                  int32_t * ids,
                                                                  const int n_rows,
                                                                  const int n_expert_used,
                                                                  const float clamp_val) {
    const int row = blockIdx.x * blockDim.y + threadIdx.y;

    if (row >= n_rows) {
        return;
@@ -139,6 +141,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
    if constexpr (with_norm) {
        wt_sum = warp_reduce_sum(wt_sum);
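        // clamp_val is the min bound of a fused GGML_OP_CLAMP; flooring the
        // reduced sum keeps the normalization below from dividing by ~0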
wt_sum = max(wt_sum, clamp_val);
        const float inv_sum = 1.0f / wt_sum;

        for (int i = 0; i < experts_per_thread; i++) {
@@ -157,6 +160,10 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
            weights[idx] = output_weights[i];
        }
    }

    if (!with_norm) {
        GGML_UNUSED(clamp_val);
    }
}

template <bool with_norm, bool delayed_softmax = false>
@@ -166,9 +173,9 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
                                 int32_t * ids,
                                 const int n_rows,
                                 const int n_expert,
                                 const int n_expert_used,
                                 const float clamp_val) {
    static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization");

    const int rows_per_block = 4;
    dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
    dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@@ -177,43 +184,43 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
    switch (n_expert) {
        case 1:
            topk_moe_cuda<1, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 2:
            topk_moe_cuda<2, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 4:
            topk_moe_cuda<4, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 8:
            topk_moe_cuda<8, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 16:
            topk_moe_cuda<16, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 32:
            topk_moe_cuda<32, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 64:
            topk_moe_cuda<64, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 128:
            topk_moe_cuda<128, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 256:
            topk_moe_cuda<256, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        case 512:
            topk_moe_cuda<512, with_norm, delayed_softmax>
                <<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val);
            break;
        default:
            GGML_ASSERT(false && "fatal error");
@@ -226,7 +233,8 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
                           const bool with_norm,
                           const bool delayed_softmax,
                           ggml_tensor * clamp) {
    GGML_ASSERT(logits->type == GGML_TYPE_F32);
    GGML_ASSERT(weights->type == GGML_TYPE_F32);
    GGML_ASSERT(ids->type == GGML_TYPE_I32);
@@ -242,18 +250,25 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
    const int n_expert_used = weights->ne[1];

    float clamp_val = -INFINITY;

    if (with_norm) {
        if (clamp) {
            clamp_val = ggml_get_op_params_f32(clamp, 0);
        }
        launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val);
    } else {
        GGML_ASSERT(clamp == nullptr);
        if (delayed_softmax) {
            launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                              clamp_val);
        } else {
            launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
                                               clamp_val);
        }
    }
}

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp) {
    float scale    = 1.0f;
    float max_bias = 0.0f;
@@ -279,13 +294,26 @@ bool ggml_cuda_should_use_topk_moe
        return false;
    }
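    // only a pure min-clamp can be fused: the max bound must still be +INF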
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
return false;
}
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) {
return false;
}
}
    return true;
}

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
    static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                            GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
                                                            GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
                                                            GGML_OP_RESHAPE };

    static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
                                                               GGML_OP_VIEW, GGML_OP_GET_ROWS };
@@ -8,8 +8,9 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
                           ggml_tensor * weights,
                           ggml_tensor * ids,
                           const bool with_norm,
                           const bool delayed_softmax = false,
                           ggml_tensor * weight_clamp = nullptr);

bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, const ggml_tensor * weights, const ggml_tensor * clamp = nullptr);

std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
@@ -18,10 +18,7 @@ static __device__ __forceinline__ float op_step(float x) {
}

static __device__ __forceinline__ float op_gelu(float x) {
    return ggml_cuda_op_gelu_single(x);
}

static __device__ __forceinline__ float op_gelu_erf(float x) {
@@ -37,7 +34,7 @@ static __device__ __forceinline__ float op_gelu_quick(float x) {
}

static __device__ __forceinline__ float op_silu(float x) {
    return ggml_cuda_op_silu_single(x);
}

static __device__ __forceinline__ float op_tanh(float x) {
@@ -84,10 +81,34 @@ static __device__ __forceinline__ float op_log(float x) {
    return logf(x);
}
static __device__ __forceinline__ float op_expm1(float x) {
return expm1f(x);
}
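// softplus(x) = log(1 + exp(x)); for x > 20 the result is ~x while expf(x)
// would overflow in float, hence the cutoff (this matches the CPU version)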
static __device__ __forceinline__ float op_softplus(float x) {
return (x > 20.0f) ? x : logf(1.0f + expf(x));
}
static __device__ __forceinline__ float op_elu(float x) { static __device__ __forceinline__ float op_elu(float x) {
return (x > 0.f) ? x : expm1f(x); return (x > 0.f) ? x : expm1f(x);
} }
static __device__ __forceinline__ float op_floor(float x) {
return floorf(x);
}
static __device__ __forceinline__ float op_ceil(float x) {
return ceilf(x);
}
static __device__ __forceinline__ float op_round(float x) {
    return roundf(x);
}
static __device__ __forceinline__ float op_trunc(float x) {
    return truncf(x);
}
template <float (*op)(float), typename T>
static __global__ void unary_op_kernel(const T * x, T * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -204,6 +225,30 @@ void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    ggml_cuda_op_unary<op_elu>(ctx, dst);
}
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_floor>(ctx, dst);
}
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_ceil>(ctx, dst);
}
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_round>(ctx, dst);
}
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_trunc>(ctx, dst);
}
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_expm1>(ctx, dst);
}
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
ggml_cuda_op_unary<op_softplus>(ctx, dst);
}
/* gated ops */

template <float (*op)(float), typename T>
@@ -317,13 +362,8 @@ static __global__ void swiglu_oai_kernel
    float xi = x[j0];
    float gi = g[j1];

    dst[i] = ggml_cuda_op_swiglu_oai_single(xi, gi, alpha, limit);
}

template <typename T>
#pragma once
#include "common.cuh" #include "common.cuh"
#define CUDA_NEG_BLOCK_SIZE 256 #define CUDA_NEG_BLOCK_SIZE 256
...@@ -60,8 +61,20 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst); ...@@ -60,8 +61,20 @@ void ggml_cuda_op_cos(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_log(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_expm1(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_softplus(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_elu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_floor(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_ceil(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_round(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_trunc(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_reglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
...@@ -75,3 +88,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); ...@@ -75,3 +88,23 @@ void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
__device__ __forceinline__ float ggml_cuda_op_silu_single(float x) {
return x / (1.0f + expf(-x));
}
__device__ __forceinline__ float ggml_cuda_op_gelu_single(float x) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
return 0.5f * x * (1.0f + tanhf(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}
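// fused SwiGLU-OAI gate: clamp x from above and g to [-limit, limit], apply
// a swish with slope alpha to x, then scale by (1 + g)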
__device__ __forceinline__ float ggml_cuda_op_swiglu_oai_single(float x, float g, float alpha = 1.702f, float limit = 7.0f) {
x = fminf(x, limit);
g = fmaxf(fminf(g, limit), -limit);
float out_glu = x / (1.0f + expf(-x * alpha));
out_glu = out_glu * (1.0f + g);
return out_glu;
}
@@ -81,6 +81,70 @@ static __global__ void upscale_f32_bilinear(const float * x, float * dst,
    dst[index] = result;
}
namespace bicubic_interpolation {
// https://en.wikipedia.org/wiki/Bicubic_interpolation#Bicubic_convolution_algorithm
__device__ const float a = -0.75f; // use alpha = -0.75 (same as PyTorch)
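// weight1 is the convolution kernel for offsets |x| <= 1, weight2 for 1 < |x| < 2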
static __device__ float weight1(float x) { return ((a + 2) * x - (a + 3)) * x * x + 1; };
static __device__ float weight2(float x) { return ((a * x - 5 * a) * x + 8 * a) * x - 4 * a; };
static __device__ float bicubic(float p0, float p1, float p2, float p3, float x) {
const float w0 = weight2(x + 1);
const float w1 = weight1(x + 0);
const float w2 = weight1(1 - x);
const float w3 = weight2(2 - x);
return p0 * w0 + p1 * w1 + p2 * w2 + p3 * w3;
};
} // namespace bicubic_interpolation
static __global__ void upscale_f32_bicubic(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset) {
using bicubic_interpolation::bicubic;
const int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
const int64_t dst_total_elements = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
if (index >= dst_total_elements) {
return;
}
const int i10_dst = index % ne10_dst;
const int i11_dst = (index / ne10_dst) % ne11_dst;
const int i12_dst = (index / (ne10_dst * ne11_dst)) % ne12_dst;
const int i13_dst = index / (ne10_dst * ne11_dst * ne12_dst);
const int i02_src = (int)(i12_dst / sf2);
const int i03_src = (int)(i13_dst / sf3);
const float y_src_f = ((float)i11_dst + pixel_offset) / sf1 - pixel_offset;
const int y0_src = (int)floorf(y_src_f);
const float dy = y_src_f - (float)y0_src;
const float x_src_f = ((float)i10_dst + pixel_offset) / sf0 - pixel_offset;
const int x0_src = (int)floorf(x_src_f);
const float dx = x_src_f - (float)x0_src;
const char * x_base = (const char *)x + (int64_t)i02_src * nb02 + (int64_t)i03_src * nb03;
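    // clamp-to-edge fetch of the 4x4 neighborhood around (x0_src, y0_src)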
auto load = [=](int x_off, int y_off) -> float {
int i00_src = max(0, min(x0_src + x_off, ne00_src - 1));
int i01_src = max(0, min(y0_src + y_off, ne01_src - 1));
return *(const float *)(x_base + (int64_t)i00_src * nb00 + (int64_t)i01_src * nb01);
};
const float result = bicubic(
bicubic(load(-1,-1), load(0,-1), load(1,-1), load(2,-1), dx),
bicubic(load(-1, 0), load(0, 0), load(1, 0), load(2, 0), dx),
bicubic(load(-1, 1), load(0, 1), load(1, 1), load(2, 1), dx),
bicubic(load(-1, 2), load(0, 2), load(1, 2), load(2, 2), dx), dy);
dst[index] = result;
}
static void upscale_f32_cuda(const float * x, float * dst,
                             const int nb00, const int nb01, const int nb02, const int nb03,
                             const int ne10, const int ne11, const int ne12, const int ne13,
@@ -104,6 +168,18 @@ static void upscale_f32_bilinear_cuda(const float * x, float * dst,
    upscale_f32_bilinear<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
static void upscale_f32_bicubic_cuda(const float * x, float * dst,
const int nb00, const int nb01, const int nb02, const int nb03,
const int ne00_src, const int ne01_src,
const int ne10_dst, const int ne11_dst, const int ne12_dst, const int ne13_dst,
const float sf0, const float sf1, const float sf2, const float sf3,
const float pixel_offset, cudaStream_t stream) {
const int64_t dst_size = ne10_dst * ne11_dst * ne12_dst * ne13_dst;
const int64_t num_blocks = (dst_size + CUDA_UPSCALE_BLOCK_SIZE - 1) / CUDA_UPSCALE_BLOCK_SIZE;
upscale_f32_bicubic<<<num_blocks, CUDA_UPSCALE_BLOCK_SIZE,0,stream>>>(x, dst, nb00, nb01, nb02, nb03, ne00_src, ne01_src, ne10_dst, ne11_dst, ne12_dst, ne13_dst, sf0, sf1, sf2, sf3, pixel_offset);
}
void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    const float * src0_d = (const float *)src0->data;
@@ -121,17 +197,22 @@ void ggml_cuda_op_upscale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    float sf2 = (float)dst->ne[2]/src0->ne[2];
    const float sf3 = (float)dst->ne[3]/src0->ne[3];
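    // align-corners maps corner pixel centers onto each other: rescale by
    // (n_dst - 1) / (n_src - 1) and drop the half-pixel offset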
float pixel_offset = 0.5f;
if (mode_flags & GGML_SCALE_FLAG_ALIGN_CORNERS) {
sf0 = dst->ne[0] > 1 && src0->ne[0] > 1 ? (float)(dst->ne[0] - 1) / (src0->ne[0] - 1) : sf0;
sf1 = dst->ne[1] > 1 && src0->ne[1] > 1 ? (float)(dst->ne[1] - 1) / (src0->ne[1] - 1) : sf1;
pixel_offset = 0.0f;
}
    if (mode == GGML_SCALE_MODE_NEAREST) {
        upscale_f32_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], sf0, sf1, sf2, sf3, stream);
    } else if (mode == GGML_SCALE_MODE_BILINEAR) {
        upscale_f32_bilinear_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                                  src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
                                  sf0, sf1, sf2, sf3, pixel_offset, stream);
} else if (mode == GGML_SCALE_MODE_BICUBIC) {
upscale_f32_bicubic_cuda(src0_d, dst_d, src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
src0->ne[0], src0->ne[1], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
sf0, sf1, sf2, sf3, pixel_offset, stream);
    }
}
@@ -109,7 +109,7 @@
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamPerThread hipStreamPerThread
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent hipStreamWaitEvent
#define cudaGraphExec_t hipGraphExec_t
#define cudaGraphNode_t hipGraphNode_t
#define cudaKernelNodeParams hipKernelNodeParams
@@ -29,10 +29,11 @@ if (CXX_IS_HIPCC)
    endif()
else()
    # Forward (AMD)GPU_TARGETS to CMAKE_HIP_ARCHITECTURES.
if(AMDGPU_TARGETS AND NOT GPU_TARGETS)
set(GPU_TARGETS ${AMDGPU_TARGETS})
endif()
    if(GPU_TARGETS AND NOT CMAKE_HIP_ARCHITECTURES)
        set(CMAKE_HIP_ARCHITECTURES ${GPU_TARGETS})
    endif()

    cmake_minimum_required(VERSION 3.21)
    enable_language(HIP)
@@ -102,7 +102,7 @@ static bool ggml_op_is_empty(enum ggml_op op) {
    }
}

static inline float ggml_compute_softplus_f32(float input) {
    return (input > 20.0f) ? input : logf(1 + expf(input));
}

//