Unverified Commit 943464cc authored by Jeffrey Morgan's avatar Jeffrey Morgan Committed by GitHub
Browse files

llama: update to commit 71e90e88 (#10192)

parent 369de832
......@@ -406,6 +406,15 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
#endif // CP_ASYNC_AVAILABLE
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_KV);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(stride_mask); GGML_UNUSED(jt); GGML_UNUSED(tile_K);
GGML_UNUSED(tile_V); GGML_UNUSED(tile_mask); GGML_UNUSED(Q_B);
GGML_UNUSED(VKQ_C); GGML_UNUSED(KQ_max); GGML_UNUSED(KQ_rowsum);
GGML_UNUSED(kb0);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
......@@ -797,6 +806,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
__syncthreads();
}
#else
GGML_UNUSED(Q_f2); GGML_UNUSED(K_h2); GGML_UNUSED(V_h2);
GGML_UNUSED(mask_h2); GGML_UNUSED(dstk); GGML_UNUSED(dstk_fixup);
GGML_UNUSED(scale); GGML_UNUSED(slope); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(stride_Q1);
GGML_UNUSED(stride_Q2); GGML_UNUSED(stride_KV); GGML_UNUSED(stride_mask);
GGML_UNUSED(jt); GGML_UNUSED(kb0_start); GGML_UNUSED(kb0_stop);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
......@@ -931,6 +946,16 @@ static __global__ void flash_attn_ext_f16(
(Q_f2, K_h2, V_h2, mask_h2, dstk, dst_meta, scale, slope, logit_softcap,
ne01, ne02, stride_Q1, stride_Q2, stride_KV, stride_mask, jt, kb0_start_kernel, kb0_stop_kernel);
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(NEW_MMA_AVAILABLE)
}
......@@ -970,7 +995,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
fattn_kernel = flash_attn_ext_f16<D, ncols1, ncols2, nwarps, KQ_per_iter, ntiles, use_logit_softcap>;
}
launch_fattn<D, ncols1, ncols2, 0, KQ_per_iter>(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, true, true);
launch_fattn<D, ncols1, ncols2, KQ_per_iter>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, FATTN_KQ_STRIDE, true, true, true);
}
......@@ -984,38 +1010,38 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/4, 4); \
extern DECL_FATTN_MMA_F16_CASE(D, (ncols)/8, 8); \
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64);
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 8)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 16)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 32)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 64)
DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 64)
// Kernels with ncols == 128 are only 4% faster due to register pressure.
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128);
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128); // Needs too much shared memory.
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 64, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 80, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2( 96, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(112, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(128, 128)
// DECL_FATTN_MMA_F16_CASE_ALL_NCOLS2(256, 128) // Needs too much shared memory.
......@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
......@@ -58,18 +58,17 @@ static __global__ void flash_attn_tile_ext_f16(
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
......@@ -105,8 +104,7 @@ static __global__ void flash_attn_tile_ext_f16(
__syncthreads();
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F16;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F16) {
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];
......@@ -271,24 +269,36 @@ static __global__ void flash_attn_tile_ext_f16(
const int i0 = i00 + 2*threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = __high2float(dst_val);
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = __low2float(dst_val);
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = __high2float(dst_val);
}
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
......@@ -296,15 +306,17 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
......@@ -324,37 +336,22 @@ void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}
......@@ -4,7 +4,7 @@
#define FATTN_KQ_STRIDE_TILE_F32 32
template<int D, int ncols, int nwarps, int parallel_blocks, bool use_logit_softcap> // D == head size
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
......@@ -52,24 +52,35 @@ static __global__ void flash_attn_tile_ext_f32(
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
return;
}
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.y + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.y / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const float2 * Q_f2 = (const float2 *) (Q + nb02* blockIdx.z + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb12*(blockIdx.z / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
......@@ -103,8 +114,7 @@ static __global__ void flash_attn_tile_ext_f32(
__syncthreads();
const int k_start = parallel_blocks == 1 ? 0 : ip*FATTN_KQ_STRIDE_TILE_F32;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE_TILE_F32) {
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
......@@ -269,25 +279,37 @@ static __global__ void flash_attn_tile_ext_f32(
const int i0 = i00 + 2*threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/(2*WARP_SIZE)];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.y + D*blockIdx.y + i0 + 1] = dst_val.y;
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 0] = dst_val.x;
dst[j_dst*D*gridDim.z + D*blockIdx.z + i0 + 1] = dst_val.y;
}
if (parallel_blocks != 1 && threadIdx.x == 0) {
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int cols_per_block, int parallel_blocks, bool use_logit_softcap>
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
......@@ -295,15 +317,17 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1, -1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
......@@ -320,37 +344,22 @@ void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_ten
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
constexpr int parallel_blocks = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, parallel_blocks, use_logit_softcap>(ctx, dst);
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
......@@ -55,17 +55,16 @@ static __global__ void flash_attn_vec_ext_f16(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f16_t dequantize_1_v = get_dequantize_1_f16(type_V);
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio);
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio);
const half * maskh = (const half *) mask + ne11*ic0;
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
......@@ -172,8 +171,7 @@ static __global__ void flash_attn_vec_ext_f16(
half2 VKQ[ncols] = {{0.0f, 0.0f}};
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
// For unknown reasons using a half array of size 1 for kqmax_new causes a performance regression,
......@@ -283,29 +281,41 @@ static __global__ void flash_attn_vec_ext_f16(
kqsum[j_VKQ] = warp_reduce_sum((float)kqsum[j_VKQ]);
half dst_val = (__low2half(VKQ[j_VKQ]) + __high2half(VKQ[j_VKQ]));
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
NO_DEVICE_CODE;
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
template <int D, ggml_type type_K, ggml_type type_V>
......@@ -325,65 +335,48 @@ void ggml_cuda_flash_attn_ext_vec_f16_case(ggml_backend_cuda_context & ctx, ggml
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
constexpr int cols_per_block = 8;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f16_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}
......
#include "common.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
template<int D, int ncols, ggml_type type_K, ggml_type type_V, bool use_logit_softcap> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
......@@ -45,6 +45,18 @@ static __global__ void flash_attn_vec_ext_f32(
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02);
GGML_UNUSED(ne03); GGML_UNUSED(ne10); GGML_UNUSED(ne11);
GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12);
GGML_UNUSED(nb13); GGML_UNUSED(nb21); GGML_UNUSED(nb22);
GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
return;
}
......@@ -55,16 +67,15 @@ static __global__ void flash_attn_vec_ext_f32(
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
constexpr dequantize_1_f32_t dequantize_1_v = get_dequantize_1_f32(type_V);
const int ic0 = (blockIdx.x / parallel_blocks) * ncols; // Index of the Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
Q += nb02* blockIdx.y + nb01*ic0;
K += nb12*(blockIdx.y / gqa_ratio);
V += nb22*(blockIdx.y / gqa_ratio); // K and V have same shape
Q += nb02* blockIdx.z + nb01*ic0;
K += nb12*(blockIdx.z / gqa_ratio);
V += nb22*(blockIdx.z / gqa_ratio); // K and V have same shape
const half * maskh = (const half *) mask + ne11*ic0;
const float slope = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slope = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
constexpr int nwarps = D / WARP_SIZE;
......@@ -115,7 +126,7 @@ static __global__ void flash_attn_vec_ext_f32(
// Set memory to zero if out of bounds:
if (ncols > 2 && ic0 + j >= ne01) {
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
tmp_q_i32[i] = 0;
......@@ -128,7 +139,7 @@ static __global__ void flash_attn_vec_ext_f32(
const float * Q_f = (const float *) (Q + j*nb01);
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
quantize_q8_1_to_shared<float2>(Q_f + 4*i0, scale, tmp_q_i32, tmp_q_ds);
}
}
......@@ -141,7 +152,7 @@ static __global__ void flash_attn_vec_ext_f32(
float2 * tmp_q_ds = (float2 *) (tmp_q_i32 + D/sizeof(int));
#pragma unroll
for (int i0 = 0; i0 < D/sizeof(int); i0 += WARP_SIZE) {
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
Q_i32[j][i0/WARP_SIZE] = tmp_q_i32[i];
......@@ -167,8 +178,7 @@ static __global__ void flash_attn_vec_ext_f32(
float VKQ[ncols] = {0.0f};
const int k_start = parallel_blocks == 1 ? 0 : ip*D;
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
for (int k_VKQ_0 = blockIdx.y*D; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*D) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new_arr[ncols];
......@@ -268,29 +278,39 @@ static __global__ void flash_attn_vec_ext_f32(
kqsum[j_VKQ] = warp_reduce_sum(kqsum[j_VKQ]);
float dst_val = VKQ[j_VKQ];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= kqsum[j_VKQ];
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
dst[j_dst*D*gridDim.y + D*blockIdx.y + tid] = dst_val;
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
dst[j_dst*D*gridDim.z + D*blockIdx.z + tid] = dst_val;
}
if (parallel_blocks != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[(ic0 + tid)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = make_float2(kqmax[tid], kqsum[tid]);
if (gridDim.y != 1 && tid < ncols && (ncols <= 2 || ic0 + tid < ne01)) {
dst_meta[((ic0 + tid)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = make_float2(kqmax[tid], kqsum[tid]);
}
#else
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap); GGML_UNUSED(ne00);
GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03); GGML_UNUSED(ne10);
GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13); GGML_UNUSED(ne31);
GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02); GGML_UNUSED(nb03);
GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13); GGML_UNUSED(nb21);
GGML_UNUSED(nb22); GGML_UNUSED(nb23); GGML_UNUSED(ne0); GGML_UNUSED(ne1);
GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, int cols_per_block, int parallel_blocks, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
template <int D, int cols_per_block, ggml_type type_K, ggml_type type_V, bool use_logit_softcap>
void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
constexpr int nwarps = D/WARP_SIZE;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, D, need_f16_K, need_f16_V, false);
}
template <int D, ggml_type type_K, ggml_type type_V>
......@@ -307,65 +327,48 @@ void ggml_cuda_flash_attn_ext_vec_f32_case(ggml_backend_cuda_context & ctx, ggml
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] == 1) {
constexpr int cols_per_block = 1;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 1;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] == 2) {
constexpr int cols_per_block = 2;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 2;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 4) {
constexpr int cols_per_block = 4;
constexpr int parallel_blocks = 4;
constexpr int cols_per_block = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
if (Q->ne[1] <= 8) {
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 4;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 8;
constexpr int parallel_blocks = 1;
constexpr int cols_per_block = 8;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>(ctx, dst);
ggml_cuda_flash_attn_ext_vec_f32_case_impl<D, cols_per_block, type_K, type_V, use_logit_softcap>(ctx, dst);
}
}
......
......@@ -7,14 +7,19 @@
#include "fattn-wmma-f16.cuh"
#ifdef FP16_MMA_AVAILABLE
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#include <mma.h>
namespace wmma = nvcuda::wmma;
#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)
#undef HIP_ENABLE_WARP_SYNC_BUILTINS // conflicts with rocWMMA headers
#include <rocwmma/rocwmma.hpp>
namespace wmma = rocwmma;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
template<int D, int ncols, int nwarps, int VKQ_stride, typename KQ_acc_t, bool use_logit_softcap>
__launch_bounds__(nwarps*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
......@@ -51,7 +56,7 @@ static __global__ void flash_attn_ext_f16(
const int ne1,
const int ne2,
const int ne3) {
#if defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
......@@ -60,19 +65,20 @@ static __global__ void flash_attn_ext_f16(
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
const int ic0 = ncols*blockIdx.x; // Index of the first Q/QKV column to work on.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
......@@ -84,16 +90,16 @@ static __global__ void flash_attn_ext_f16(
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const float * Q_f = (const float *) (Q + nb02* blockIdx.z + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.z / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.z / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const float slopef = get_alibi_slope(max_bias, blockIdx.z, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
......@@ -132,9 +138,9 @@ static __global__ void flash_attn_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
......@@ -146,9 +152,9 @@ static __global__ void flash_attn_ext_f16(
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
if (i0 + warp_size > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
......@@ -162,34 +168,34 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
wmma::fill_fragment(KQ_c[j], static_cast<KQ_acc_t>(0.0f));
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, wmma::mem_col_major);
}
}
......@@ -202,27 +208,27 @@ static __global__ void flash_attn_ext_f16(
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
float KQ_f_tmp[FATTN_KQ_STRIDE / warp_size];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
KQ_f_tmp[k0/warp_size] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
KQ_f_tmp[k0/warp_size] = logit_softcap*tanhf(KQ_f_tmp[k0/warp_size]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
KQ_f_tmp[k0/warp_size] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/warp_size]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
KQ_max_new = warp_reduce_max<warp_size>(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
......@@ -233,48 +239,48 @@ static __global__ void flash_attn_ext_f16(
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += warp_size) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
const float diff = KQ_f_tmp[k0/warp_size] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/warp_size] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
KQ_f_tmp[k0/warp_size] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
KQ_rowsum_add += KQ_f_tmp[k0/warp_size];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/warp_size];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*warp_size)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
KQ2_tmp[k0/warp_size] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/warp_size] = h2exp(KQ2_tmp[k0/warp_size]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/warp_size] = (KQ2_tmp[k0/warp_size] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/warp_size] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
KQ2_tmp[k0/warp_size] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
KQ2_tmp[k0/warp_size] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/warp_size]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
KQ_max_new = __half2half2(warp_reduce_max<warp_size>(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
......@@ -283,17 +289,17 @@ static __global__ void flash_attn_ext_f16(
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += warp_size) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const half2 diff = KQ2_tmp[k0/warp_size] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/warp_size] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
*((uint32_t *) &KQ2_tmp[k0/warp_size]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/warp_size];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/warp_size];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
KQ_rowsum_add = warp_reduce_sum<warp_size>(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
......@@ -308,7 +314,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
......@@ -320,7 +326,7 @@ static __global__ void flash_attn_ext_f16(
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], static_cast<half>(0.0f));
}
#pragma unroll
......@@ -328,10 +334,10 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
......@@ -343,10 +349,10 @@ static __global__ void flash_attn_ext_f16(
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
D_padded, wmma::mem_col_major);
}
}
......@@ -364,9 +370,9 @@ static __global__ void flash_attn_ext_f16(
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
if (i0 + warp_size > D/2 && i >= D/2) {
break;
}
......@@ -388,7 +394,7 @@ static __global__ void flash_attn_ext_f16(
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
const int j_dst = (ic0 + j_VKQ)*gridDim.y + blockIdx.y;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
......@@ -398,19 +404,19 @@ static __global__ void flash_attn_ext_f16(
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
for (int i0 = 0; i0 < D; i0 += warp_size) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
if (i0 + warp_size > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
if (gridDim.y == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
dst[j_dst*gridDim.z*D + blockIdx.z*D + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
if (gridDim.y == 1 || threadIdx.x != 0) {
continue;
}
......@@ -421,11 +427,21 @@ static __global__ void flash_attn_ext_f16(
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
dst_meta[((ic0 + j_VKQ)*gridDim.z + blockIdx.z) * gridDim.y + blockIdx.y] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
GGML_UNUSED(Q); GGML_UNUSED(K); GGML_UNUSED(V); GGML_UNUSED(mask);
GGML_UNUSED(dst); GGML_UNUSED(dst_meta); GGML_UNUSED(scale);
GGML_UNUSED(max_bias); GGML_UNUSED(m0); GGML_UNUSED(m1);
GGML_UNUSED(n_head_log2); GGML_UNUSED(logit_softcap);
GGML_UNUSED(ne00); GGML_UNUSED(ne01); GGML_UNUSED(ne02); GGML_UNUSED(ne03);
GGML_UNUSED(ne10); GGML_UNUSED(ne11); GGML_UNUSED(ne12); GGML_UNUSED(ne13);
GGML_UNUSED(ne31); GGML_UNUSED(nb31); GGML_UNUSED(nb01); GGML_UNUSED(nb02);
GGML_UNUSED(nb03); GGML_UNUSED(nb11); GGML_UNUSED(nb12); GGML_UNUSED(nb13);
GGML_UNUSED(nb21); GGML_UNUSED(nb22); GGML_UNUSED(nb23);
GGML_UNUSED(ne0); GGML_UNUSED(ne1); GGML_UNUSED(ne2); GGML_UNUSED(ne3);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE)))
}
constexpr int get_max_power_of_2(int x) {
......@@ -455,59 +471,26 @@ static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 1, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
launch_fattn<D, cols_per_block, 1, -1>(ctx, dst, fattn_kernel, nwarps, 0, FATTN_KQ_STRIDE, true, true, false, warp_size);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
......@@ -515,6 +498,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
const int warp_size = ggml_cuda_info().devices[ctx.device].warp_size;
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
......@@ -571,7 +555,8 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
return;
}
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
if (Q->ne[1] <= 8 && Q->ne[0] % warp_size == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
......@@ -592,6 +577,7 @@ void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_ten
}
return;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
......
......@@ -250,10 +250,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
ggml_cuda_set_device(ctx.device);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const int warp_size = ggml_cuda_info().devices[ggml_cuda_get_device()].warp_size;
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (cc >= GGML_CUDA_CC_OFFSET_AMD) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
#if defined(GGML_HIP_ROCWMMA_FATTN)
if (fp16_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
#endif // defined(GGML_HIP_ROCWMMA_FATTN)
// On AMD the tile kernels perform poorly, use the vec kernel instead:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
......@@ -273,13 +281,13 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
if (!fp16_mma_available(cc)) {
if (prec == GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
}
} else {
if (Q->ne[1] <= 8) {
if (Q->ne[1] <= 8 || Q->ne[0] == 256) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
......@@ -288,21 +296,21 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
return;
}
const int gqa_ratio = Q->ne[2] / K->ne[2];
const bool mma_fast_for_bs1 = fp16_mma_available(cc) && gqa_ratio % 2 == 0 &&
K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16 && mask;
if (Q->ne[1] == 1 && Q->ne[0] % (2*WARP_SIZE) == 0 && !mma_fast_for_bs1) {
const bool gqa_opt_applies = ((Q->ne[2] / K->ne[2]) % 2 == 0) && mask; // The mma-based kernels have GQA-specific optimizations
const bool mma_needs_data_conversion = K->type != GGML_TYPE_F16 || V->type != GGML_TYPE_F16;
const bool mma_faster_for_bs1 = new_mma_available(cc) && gqa_opt_applies && cc < GGML_CUDA_CC_ADA_LOVELACE && !mma_needs_data_conversion;
const bool can_use_vector_kernel = Q->ne[0] % (2*warp_size) == 0;
if (Q->ne[1] == 1 && can_use_vector_kernel && !mma_faster_for_bs1) {
if (prec == GGML_PREC_DEFAULT) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
return;
} else if(Q->ne[0] <= 128) {
} else {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
return;
}
return;
}
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (cc == GGML_CUDA_CC_VOLTA) {
if (fp16_mma_available(cc) && !new_mma_available(cc)) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
......
......@@ -31,12 +31,14 @@
#include "ggml-cuda/rope.cuh"
#include "ggml-cuda/scale.cuh"
#include "ggml-cuda/softmax.cuh"
#include "ggml-cuda/ssm-conv.cuh"
#include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/unary.cuh"
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/wkv.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
......@@ -262,9 +264,11 @@ static ggml_cuda_device_info ggml_cuda_init() {
id, prop.name, prop.gcnArchName, info.devices[id].cc & 0xffff,
device_vmm ? "yes" : "no", prop.warpSize);
#elif defined(GGML_USE_MUSA)
// TODO: refine the .cc to reflect MUSA's actual CC capabilities
// FIXME: Ensure compatibility with varying warp sizes across different MUSA archs.
info.devices[id].warp_size = 32;
info.devices[id].smpbo = prop.sharedMemPerBlockOptin;
info.devices[id].cc = 100*prop.major + 10*prop.minor;
info.devices[id].cc = GGML_CUDA_CC_OFFSET_MTHREADS + prop.major * 0x100;
info.devices[id].cc += prop.minor * 0x10;
GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s\n",
id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no");
#else
......@@ -541,12 +545,12 @@ static void * ggml_backend_cuda_buffer_get_base(ggml_backend_buffer_t buffer) {
return ctx->dev_ptr;
}
static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
static enum ggml_status ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
ggml_backend_cuda_buffer_context * ctx = (ggml_backend_cuda_buffer_context *)buffer->context;
if (tensor->view_src != NULL) {
assert(tensor->view_src->buffer->buft == buffer->buft);
return;
return GGML_STATUS_SUCCESS;
}
if (ggml_is_quantized(tensor->type) && tensor->view_src == nullptr && ggml_backend_buffer_get_usage(buffer) != GGML_BACKEND_BUFFER_USAGE_COMPUTE) {
......@@ -559,6 +563,7 @@ static void ggml_backend_cuda_buffer_init_tensor(ggml_backend_buffer_t buffer, g
CUDA_CHECK(cudaMemset((char *)tensor->data + original_size, 0, padded_size - original_size));
}
}
return GGML_STATUS_SUCCESS;
}
static void ggml_backend_cuda_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
......@@ -794,7 +799,7 @@ static void * ggml_backend_cuda_split_buffer_get_base(ggml_backend_buffer_t buff
GGML_UNUSED(buffer);
}
static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
static enum ggml_status ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) {
GGML_ASSERT(tensor->view_src == nullptr); // views of split tensors are not supported
ggml_backend_cuda_split_buffer_context * ctx = (ggml_backend_cuda_split_buffer_context *)buffer->context;
......@@ -840,6 +845,7 @@ static void ggml_backend_cuda_split_buffer_init_tensor(ggml_backend_buffer_t buf
}
}
tensor->extra = extra;
return GGML_STATUS_SUCCESS;
}
static void ggml_backend_cuda_split_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
......@@ -1187,11 +1193,39 @@ static void ggml_cuda_op_mul_mat_cublas(
// ldc == nrows of the matrix that cuBLAS writes into
int64_t ldc = id == ctx.device ? ne0 : row_diff;
const int compute_capability = ggml_cuda_info().devices[id].cc;
const int cc = ggml_cuda_info().devices[id].cc;
const bool use_fp16 = (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1] && dst->op_params[0] == GGML_PREC_DEFAULT;
if (compute_capability >= GGML_CUDA_CC_VOLTA && use_fp16) {
if (src0->type == GGML_TYPE_BF16 && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
ggml_cuda_pool_alloc<nv_bfloat16> src1_as_bf16(ctx.pool(id));
if (src1->type != GGML_TYPE_BF16) {
const to_bf16_cuda_t to_bf16_cuda = ggml_get_to_bf16_cuda(src1->type);
GGML_ASSERT(to_bf16_cuda != nullptr);
size_t ne = src1_ncols*ne10;
src1_as_bf16.alloc(ne);
to_bf16_cuda(src1_ddf_i, src1_as_bf16.get(), ne, stream);
}
const nv_bfloat16 * src1_ptr = src1->type == GGML_TYPE_BF16 ? (const nv_bfloat16 *) src1_ddf_i : src1_as_bf16.get();
const nv_bfloat16 * src0_ptr = (const nv_bfloat16 *)src0_dd_i;
ggml_cuda_pool_alloc<nv_bfloat16> dst_bf16(ctx.pool(id), row_diff*src1_ncols);
const float alpha_f32 = 1.0f;
const float beta_f32 = 0.0f;
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
CUBLAS_CHECK(
cublasGemmEx(ctx.cublas_handle(id), CUBLAS_OP_T, CUBLAS_OP_N,
row_diff, src1_ncols, ne10,
&alpha_f32, src0_ptr, CUDA_R_16BF, ne00,
src1_ptr, CUDA_R_16BF, ne10,
&beta_f32, dst_bf16.get(), CUDA_R_16BF, ldc,
CUBLAS_COMPUTE_32F,
CUBLAS_GEMM_DEFAULT_TENSOR_OP));
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_BF16);
to_fp32_cuda(dst_bf16.get(), dst_dd_i, row_diff*src1_ncols, stream);
} else if (((GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || GGML_CUDA_CC_IS_AMD(cc)) && use_fp16) {
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
ggml_cuda_pool_alloc<half> src0_as_f16(ctx.pool(id));
if (src0->type != GGML_TYPE_F16) {
......@@ -1215,7 +1249,7 @@ static void ggml_cuda_op_mul_mat_cublas(
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
......@@ -1758,7 +1792,9 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
beta = &beta_f32;
}
if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
if (GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;
......@@ -1835,7 +1871,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
}
#endif
if (dst->op_params[0] == GGML_PREC_DEFAULT) {
if (dst->op_params[0] == GGML_PREC_DEFAULT && cu_data_type == CUDA_R_16F) {
const to_fp32_cuda_t to_fp32_cuda = ggml_get_to_fp32_cuda(GGML_TYPE_F16);
to_fp32_cuda(dst_f16.get(), dst_ddf, ne_dst, main_stream);
}
......@@ -2148,6 +2184,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(dst)) {
case GGML_UNARY_OP_ABS:
ggml_cuda_op_abs(ctx, dst);
break;
case GGML_UNARY_OP_SGN:
ggml_cuda_op_sgn(ctx, dst);
break;
case GGML_UNARY_OP_NEG:
ggml_cuda_op_neg(ctx, dst);
break;
......@@ -2191,6 +2233,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GROUP_NORM:
ggml_cuda_op_group_norm(ctx, dst);
break;
case GGML_OP_L2_NORM:
ggml_cuda_op_l2_norm(ctx, dst);
break;
case GGML_OP_CONCAT:
ggml_cuda_op_concat(ctx, dst);
break;
......@@ -2248,6 +2293,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CLAMP:
ggml_cuda_op_clamp(ctx, dst);
break;
case GGML_OP_LOG:
ggml_cuda_op_log(ctx, dst);
break;
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
......@@ -2284,6 +2332,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst);
break;
case GGML_OP_SSM_CONV:
ggml_cuda_op_ssm_conv(ctx, dst);
break;
case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst);
break;
case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst);
break;
......@@ -2301,6 +2355,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_GATED_LINEAR_ATTN:
ggml_cuda_op_gated_linear_attn(ctx, dst);
break;
case GGML_OP_RWKV_WKV7:
ggml_cuda_op_rwkv_wkv7(ctx, dst);
break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
ggml_cuda_cross_entropy_loss_back(ctx, dst);
break;
......@@ -2568,7 +2625,7 @@ static void maintain_cuda_graph(ggml_backend_cuda_context * cuda_ctx, std::vecto
for (size_t i = 0; i < cuda_ctx->cuda_graph->num_nodes; i++) {
if(count(ggml_cuda_cpy_fn_ptrs.begin(), ggml_cuda_cpy_fn_ptrs.end(), cuda_ctx->cuda_graph->params[i].func) > 0) {
char ** updated_kernel_arg_ptr = cuda_ctx->cuda_graph->updated_kernel_arg.at(k++);
cuda_ctx->cuda_graph->params[i].kernelParams[1] = updated_kernel_arg_ptr;
*(void**)cuda_ctx->cuda_graph->params[i].kernelParams[1] = *(void**)updated_kernel_arg_ptr;
CUDA_CHECK(cudaGraphKernelNodeSetParams(cuda_ctx->cuda_graph->nodes[i], &cuda_ctx->cuda_graph->params[i]));
}
}
......@@ -2607,13 +2664,15 @@ static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx,
static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info;
#ifdef __HIP_PLATFORM_AMD__
hipGraphNode_t errorNode;
hipError_t stat = hipGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#else
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &result_info);
#endif
#else
cudaGraphNode_t errorNode;
cudaGraphExecUpdateResult result_info;
cudaError_t stat = cudaGraphExecUpdate(cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, &errorNode, &result_info);
#endif // CUDART_VERSION >= 12000
if (stat == cudaErrorGraphExecUpdateFailure) {
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: CUDA graph update failed\n", __func__);
......@@ -2968,6 +3027,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_ABS:
case GGML_UNARY_OP_SGN:
case GGML_UNARY_OP_NEG:
case GGML_UNARY_OP_STEP:
case GGML_UNARY_OP_GELU:
......@@ -3071,6 +3132,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_BF16) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) {
return true;
}
......@@ -3150,10 +3214,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return false;
} break;
case GGML_OP_SILU_BACK:
return ggml_is_contiguous(op->src[0]);
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
......@@ -3174,6 +3239,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
case GGML_OP_LOG:
case GGML_OP_SSM_SCAN:
case GGML_OP_SSM_CONV:
return true;
case GGML_OP_CONT:
return op->src[0]->type != GGML_TYPE_BF16;
......@@ -3201,6 +3269,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_PAD:
case GGML_OP_UNPAD:
case GGML_OP_ARANGE:
......@@ -3208,11 +3277,22 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_LEAKY_RELU:
case GGML_OP_RWKV_WKV6:
case GGML_OP_GATED_LINEAR_ATTN:
case GGML_OP_RWKV_WKV7:
return true;
case GGML_OP_FLASH_ATTN_EXT: {
#ifndef FLASH_ATTN_AVAILABLE
return false;
#endif // FLASH_ATTN_AVAILABLE
if (op->src[1]->ne[0] != op->src[2]->ne[0]) {
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[3] != 1) {
return false;
}
if (op->src[1]->type == GGML_TYPE_BF16 || op->src[2]->type == GGML_TYPE_BF16) {
return false;
}
......
......@@ -26,6 +26,7 @@ static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "=r"(ret) : "r"(x));
#else
GGML_UNUSED(x);
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
return ret;
......@@ -178,6 +179,7 @@ namespace ggml_cuda_mma {
: "l"(xs));
#else
load_generic(xs0, stride);
GGML_UNUSED(t);
#endif // NEW_MMA_AVAILABLE
}
......
......@@ -27,8 +27,8 @@ void ggml_cuda_op_mul_mat_q(
// The stream-k decomposition is only faster for recent NVIDIA GPUs.
// Also its fixup needs to allocate a temporary buffer in the memory pool.
// There are multiple parallel CUDA streams for src1_ncols != ne11 which would introduce a race condition for this buffer.
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA &&
cc < GGML_CUDA_CC_OFFSET_AMD && src1_ncols == ne11;
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) &&
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && src1_ncols == ne11;
const mmq_args args = {src0_dd_i, src1_ddq_i, dst_dd_i, ne00, row_diff, stride00, src1_padded_row_size, src1_ncols, ne11, nrows_dst, use_stream_k};
switch (src0->type) {
......@@ -145,7 +145,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return true;
#endif //GGML_CUDA_FORCE_MMQ
if (cc < GGML_CUDA_CC_OFFSET_AMD) {
if (GGML_CUDA_CC_IS_NVIDIA(cc)) {
return !fp16_mma_hardware_available(cc) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
......
......@@ -90,7 +90,7 @@ struct tile_x_sizes {
static int get_mmq_x_max_host(const int cc) {
return new_mma_available(cc) ? 128 :
ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD ?
GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ?
#ifdef GGML_CUDA_FORCE_MMQ
128 : 64;
#else
......@@ -109,9 +109,9 @@ static constexpr __device__ int get_mmq_x_max_device() {
#if __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#ifdef GGML_CUDA_FORCE_MMQ
return MMQ_DP4A_MAX_BATCH_SIZE;
#else // GGML_CUDA_FORCE_MMQ
return 128;
#else // GGML_CUDA_FORCE_MMQ
return MMQ_DP4A_MAX_BATCH_SIZE;
#endif // GGML_CUDA_FORCE_MMQ
#else // __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
......@@ -123,8 +123,8 @@ static constexpr __device__ int get_mmq_x_max_device() {
}
static int get_mmq_y_host(const int cc) {
return cc >= GGML_CUDA_CC_OFFSET_AMD ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
(ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA ? 128 : 64);
return GGML_CUDA_CC_IS_AMD(cc) ? (GGML_CUDA_CC_IS_RDNA1(cc) ? 64 : 128) :
((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) ? 128 : 64);
}
static constexpr __device__ int get_mmq_y_device() {
......@@ -945,7 +945,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_16_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
......@@ -1024,7 +1024,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
}
#pragma unroll
for (int k01 = 0; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
for (int k01 = 0; k01 < WARP_SIZE/2; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
......@@ -1035,19 +1035,34 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_dp4a(
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (k01 < WARP_SIZE/2) {
constexpr int ns = 2;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
} else {
constexpr int ns = 1;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
constexpr int ns = 2;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
// Some compilers fail to unroll the loop over k01 if there is a conditional statement for ns in the inner loop.
// As a workaround 2 separate loops are used instead.
#pragma unroll
for (int k01 = WARP_SIZE/2; k01 < WARP_SIZE; k01 += QR2_K*VDR_Q2_K_Q8_1_MMQ) {
const int k0 = k00 + k01;
#pragma unroll
for (int j0 = 0; j0 < mmq_x; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < mmq_y; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
constexpr int ns = 1;
sum[j0/nwarps*mmq_y/WARP_SIZE + i0/WARP_SIZE] += vec_dot_q2_K_q8_1_impl_mmq<ns>(
&x_qs[i*(2*WARP_SIZE + 1) + k0], &y_qs[j*MMQ_TILE_Y_K + k01],
&x_dm[i*(WARP_SIZE + 1) + k0/4], k01 < WARP_SIZE/2 ? y_df[j0/nwarps].x : y_df[j0/nwarps].y,
&y_ds[j*MMQ_TILE_Y_K + (1 + k01/QI8_1)]);
}
}
}
......@@ -1176,7 +1191,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
......@@ -1253,7 +1268,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const float d = bxi->d;
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_df[i*MMQ_MMA_TILE_X_K_Q3_K + sizeof(int)*(threadIdx.x % (WARP_SIZE/8)) + l] = d*sc8[l];
}
#else
......@@ -1376,7 +1391,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
......@@ -1517,7 +1532,7 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
const half2 dm = bxi->dm * make_half2(1.0f, -1.0f);
#pragma unroll
for (int l = 0; l < sizeof(int); ++l) {
for (int l = 0; l < int(sizeof(int)); ++l) {
x_dm[i*MMQ_MMA_TILE_X_K_Q8_1 + sizeof(int)*ksc + l] = dm*make_half2(sc8[l], m8[l]);
}
}
......@@ -1810,7 +1825,7 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma(
}
}
#else
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum);
GGML_UNUSED(x); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k00);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
......@@ -2570,6 +2585,8 @@ static __device__ void mul_mat_q_process_tile(
} else {
write_back(sum, dst + jt*mmq_x*ne0 + it*mmq_y, ne0, tile_x_max_i, tile_y_max_j);
}
GGML_UNUSED(ne00); GGML_UNUSED(ne10);
}
......@@ -2695,7 +2712,7 @@ static __global__ void mul_mat_q_stream_k_fixup(
const int it = (kbc_stop - jt*(blocks_per_ne00*nty)) / blocks_per_ne00;
// Skip fixup tile if it's unrelated to the output tile assigned to this CUDA block:
if (it != blockIdx.x || jt != blockIdx.y) {
if ((unsigned)it != blockIdx.x || (unsigned)jt != blockIdx.y) {
continue;
}
......@@ -2772,14 +2789,14 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
const int shmem = mmq_get_shmem<type>(mmq_x, mmq_y, cc);
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
static bool shmem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!shmem_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, false>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_q<type, mmq_x, MMQ_NWARPS, true>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem));
shmem_limit_raised[id] = true;
}
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && !defined(GGML_USE_MUSA)
const int nty = (args.ne01 + mmq_y - 1) / mmq_y;
const int ntx = (args.ne11 + mmq_x - 1) / mmq_x;
......@@ -2825,14 +2842,13 @@ static void launch_mul_mat_q(ggml_backend_cuda_context & ctx, const mmq_args & a
template <ggml_type type>
void mul_mat_q_case(ggml_backend_cuda_context & ctx, const mmq_args & args, cudaStream_t stream) {
const int id = ggml_cuda_get_device();
const int nsm = ggml_cuda_info().devices[id].nsm;
const int cc = ggml_cuda_info().devices[id].cc;
const int smpbo = ggml_cuda_info().devices[id].smpbo;
const int mmq_x_max = get_mmq_x_max_host(cc);
const int mmq_y = get_mmq_y_host(cc);
const int block_num_y = (args.ne01 + mmq_y - 1) / mmq_y;
const bool use_stream_k = ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA && cc < GGML_CUDA_CC_OFFSET_AMD;
const bool use_stream_k = GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA;
int mmq_x_best = 0;
int nparts_best = INT_MAX;
......
......@@ -29,7 +29,7 @@ static __global__ void mul_mat_vec(
__syncthreads();
}
float sumf;
float sumf = 0.0f;
if constexpr (std::is_same<T, half>::value) {
const half2 * x2 = (const half2 *) x;
......
......@@ -47,11 +47,89 @@ static constexpr __device__ int get_vdr_mmvq(ggml_type type) {
1;
}
enum mmvq_parameter_table_id {
MMVQ_PARAMETERS_GENERIC = 0,
MMVQ_PARAMETERS_GCN,
MMVQ_PARAMETERS_RDNA2
};
static constexpr __device__ mmvq_parameter_table_id get_device_table_id() {
#if defined(RDNA2) || defined(RDNA3) || defined(RDNA4)
return MMVQ_PARAMETERS_RDNA2;
#elif defined(GCN) || defined(CDNA)
return MMVQ_PARAMETERS_GCN;
#else
return MMVQ_PARAMETERS_GENERIC;
#endif
}
static __host__ mmvq_parameter_table_id get_device_table_id(int cc) {
if (GGML_CUDA_CC_IS_RDNA2(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_RDNA4(cc)) {
return MMVQ_PARAMETERS_RDNA2;
}
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return MMVQ_PARAMETERS_GCN;
}
return MMVQ_PARAMETERS_GENERIC;
}
static constexpr __host__ __device__ int calc_nwarps(int ncols_y, mmvq_parameter_table_id table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC) {
switch (ncols_y) {
case 1:
case 2:
case 3:
case 4:
return 4;
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
} else if (table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
case 1:
case 2:
case 3:
case 4:
return 2;
case 5:
case 6:
case 7:
case 8:
default:
return 1;
}
}
return 1;
}
static constexpr __host__ __device__ int calc_rows_per_block(int ncols_y, int table_id) {
if (table_id == MMVQ_PARAMETERS_GENERIC || table_id == MMVQ_PARAMETERS_GCN) {
switch (ncols_y) {
case 1:
return 1;
case 2:
case 3:
case 4:
case 5:
case 6:
case 7:
case 8:
return 2;
default:
return 1;
}
}
return 1;
}
template <ggml_type type, int ncols_y>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
// tell the compiler to use as many registers as it wants, see nwarps definition below
__launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(calc_nwarps(ncols_y, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
......@@ -59,25 +137,21 @@ static __global__ void mul_mat_vec_q(
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
constexpr int vdr = get_vdr_mmvq(type);
constexpr mmvq_parameter_table_id table_id = get_device_table_id();
constexpr int nwarps = calc_nwarps(ncols_y, table_id);
constexpr int rows_per_cuda_block = calc_rows_per_block(ncols_y, table_id);
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type);
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4))
constexpr int nwarps = 1;
constexpr int rows_per_cuda_block = 1;
#else
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3) && !defined(RDNA4)
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
const int tid = warp_size*threadIdx.y + threadIdx.x;
const int row0 = rows_per_cuda_block*blockIdx.x;
const int blocks_per_row_x = ncols_x / qk;
const int blocks_per_col_y = nrows_y / QK8_1;
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {0.0f};
// partial sum for each thread
float tmp[ncols_y][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = (const block_q8_1 *) vy;
......@@ -96,7 +170,7 @@ static __global__ void mul_mat_vec_q(
}
}
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][WARP_SIZE];
__shared__ float tmp_shared[nwarps-1 > 0 ? nwarps-1 : 1][ncols_y][rows_per_cuda_block][warp_size];
if (threadIdx.y > 0) {
#pragma unroll
for (int j = 0; j < ncols_y; ++j) {
......@@ -120,13 +194,22 @@ static __global__ void mul_mat_vec_q(
for (int l = 0; l < nwarps-1; ++l) {
tmp[j][i] += tmp_shared[l][j][i][threadIdx.x];
}
tmp[j][i] = warp_reduce_sum(tmp[j][i]);
tmp[j][i] = warp_reduce_sum<warp_size>(tmp[j][i]);
}
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < nrows_dst)) {
if (threadIdx.x < rows_per_cuda_block && (rows_per_cuda_block == 1 || row0 + threadIdx.x < (unsigned)nrows_dst)) {
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
}
}
GGML_UNUSED(nrows_x);
}
static std::pair<dim3, dim3> calc_launch_params(const int ncols_y, const int nrows_x, const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_y, table_id) - 1) / calc_rows_per_block(ncols_y, table_id);
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(warp_size, calc_nwarps(ncols_y, table_id), 1);
return {block_nums, block_dims};
}
template <ggml_type type>
......@@ -137,65 +220,67 @@ static void mul_mat_vec_q_cuda(
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE);
int id = ggml_cuda_get_device();
int64_t nwarps = 1;
int64_t rows_per_cuda_block = 1;
if (ggml_cuda_info().devices[id].cc < GGML_CUDA_CC_RDNA2) { // NVIDIA and AMD older than RDNA2
switch(ncols_y) {
case 1:
nwarps = 4;
rows_per_cuda_block = 1;
break;
case 2:
case 3:
case 4:
nwarps = 4;
rows_per_cuda_block = 2;
break;
case 5:
case 6:
case 7:
case 8:
nwarps = 2;
rows_per_cuda_block = 2;
break;
default:
GGML_ABORT("fatal error");
break;
}
}
const int64_t nblocks = (nrows_x + rows_per_cuda_block - 1) / rows_per_cuda_block;
const dim3 block_nums(nblocks, 1, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
switch (ncols_y) {
case 1:
mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 2:
mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 3:
mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 4:
mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 5:
mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 6:
mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 7:
mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
case 8:
mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
{
constexpr int c_ncols_y = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_y, nrows_x, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_y><<<dims.first, dims.second, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst);
break;
}
default:
GGML_ABORT("fatal error");
break;
......
......@@ -201,6 +201,85 @@ static __global__ void rms_norm_back_f32(
}
}
// template <int block_size>
// static __global__ void l2_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
// const int row = blockIdx.x*blockDim.y + threadIdx.y;
// const int tid = threadIdx.x;
// float tmp = 0.0f; // partial sum for thread in warp
// for (int col = tid; col < ncols; col += block_size) {
// const float xi = x[row*ncols + col];
// tmp += xi * xi;
// }
// // sum up partial sums
// tmp = warp_reduce_sum(tmp);
// if (block_size > WARP_SIZE) {
// __shared__ float s_sum[32];
// int warp_id = threadIdx.x / WARP_SIZE;
// int lane_id = threadIdx.x % WARP_SIZE;
// if (lane_id == 0) {
// s_sum[warp_id] = tmp;
// }
// __syncthreads();
// tmp = s_sum[lane_id];
// tmp = warp_reduce_sum(tmp);
// }
// // from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
// const float scale = rsqrtf(fmaxf(tmp, eps * eps));
// for (int col = tid; col < ncols; col += block_size) {
// dst[row*ncols + col] = scale * x[row*ncols + col];
// }
// }
template <int block_size>
static __global__ void l2_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
for (int col = tid; col < ncols; col += block_size) {
const float xi = x[col];
tmp += xi * xi;
}
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = s_sum[lane_id];
tmp = warp_reduce_sum(tmp);
}
// from https://pytorch.org/docs/stable/generated/torch.nn.functional.normalize.html
const float scale = rsqrtf(fmaxf(tmp, eps * eps));
for (int col = tid; col < ncols; col += block_size) {
dst[col] = scale * x[col];
}
}
static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
......@@ -248,6 +327,19 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
}
}
static void l2_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
l2_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
l2_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
......@@ -340,3 +432,27 @@ void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * d
rms_norm_back_f32_cuda(grad_d, src0f_d, dst_d, ne00, nrows, eps, stream);
}
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
l2_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
......@@ -7,3 +7,5 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_l2_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
......@@ -14,7 +14,7 @@ static __global__ void pad_f32(const float * x, float * dst, const int ne0, cons
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < ne01 && blockIdx.z < ne02*ne03) {
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
......
#include "ssm-conv.cuh"
template <size_t split_d_inner, size_t d_conv>
static __global__ void ssm_conv_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2, const int src1_nb1,
float * __restrict__ dst, const int dst_nb0, const int dst_nb1, const int dst_nb2,
const int64_t n_t) {
GGML_UNUSED(src0_nb0);
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1);
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block = (float *) ((char *) dst + bidx * dst_nb2 + bidy * split_d_inner * dst_nb0);
const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}
for (int64_t i = 0; i < n_t; i++) {
float sumf = 0.0f;
if (i == 0) {
for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
template <size_t split_d_inner, size_t d_conv, int64_t split_n_t>
static __global__ void ssm_conv_long_token_f32(const float * __restrict__ src0, const float * __restrict__ src1,
const int src0_nb0, const int src0_nb1, const int src0_nb2,
const int src1_nb1, float * __restrict__ dst, const int dst_nb0,
const int dst_nb1, const int dst_nb2, const int64_t n_t) {
const int tid = threadIdx.x;
const int bidx = blockIdx.x;
const int bidy = blockIdx.y;
const int bidz = blockIdx.z;
const float * x_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * split_d_inner * src0_nb1 +
bidz * split_n_t * src0_nb0);
const float * w_block = (const float *) ((const char *) src1 + bidy * split_d_inner * src1_nb1);
float * y_block =
(float *) ((char *) dst + bidx * dst_nb2 + bidz * split_n_t * dst_nb1 + bidy * split_d_inner * dst_nb0);
const int stride_x = src0_nb1 / sizeof(float);
const int stride_w = src1_nb1 / sizeof(float);
const int stride_y = dst_nb1 / sizeof(float);
float x[d_conv] = { 0.0f };
float w[d_conv] = { 0.0f };
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
w[j] = w_block[tid * stride_w + j];
}
#pragma unroll
for (int64_t i = 0; i < split_n_t; i++) {
if (bidz * split_n_t + i < n_t) {
float sumf = 0.0f;
if (i == 0) {
for (size_t j = 0; j < d_conv; j++) {
x[j] = x_block[tid * stride_x + j];
}
} else {
x[(i - 1) % d_conv] = x_block[tid * stride_x + i + d_conv - 1];
}
#pragma unroll
for (size_t j = 0; j < d_conv; j++) {
sumf += x[(i + j) % d_conv] * w[j];
}
y_block[i * stride_y + tid] = sumf;
}
}
}
static void ssm_conv_f32_cuda(const float * src0, const float * src1, const int src0_nb0, const int src0_nb1,
const int src0_nb2, const int src1_nb1, float * dst, const int dst_nb0, const int dst_nb1,
const int dst_nb2, const int64_t nc, const int64_t nr, const int64_t n_t,
const int64_t n_s, cudaStream_t stream) {
const int threads = 128;
GGML_ASSERT(nr % threads == 0);
if (n_t <= 32) {
const dim3 blocks(n_s, (nr + threads - 1) / threads, 1);
if (nc == 4) {
ssm_conv_f32<threads, 4><<<blocks, threads, 0, stream>>>(src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1,
dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 now.");
}
} else {
if (nc == 4) {
const int64_t split_n_t = 32;
dim3 blocks(n_s, (nr + threads - 1) / threads, (n_t + split_n_t - 1) / split_n_t);
ssm_conv_long_token_f32<threads, 4, split_n_t><<<blocks, threads, 0, stream>>>(
src0, src1, src0_nb0, src0_nb1, src0_nb2, src1_nb1, dst, dst_nb0, dst_nb1, dst_nb2, n_t);
} else {
GGML_ABORT("Only support kernel size = 4 right now.");
}
}
}
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // conv_x
const struct ggml_tensor * src1 = dst->src[1]; // conv1d.weight
const int64_t nc = src1->ne[0]; // d_conv
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = dst->ne[1]; // tokens per sequence
const int64_t n_s = dst->ne[2]; // number of sequences in the batch
GGML_ASSERT(dst->ne[0] == nr);
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_conv_f32_cuda(src0_d, src1_d, src0->nb[0], src0->nb[1], src0->nb[2], src1->nb[1], dst_d, dst->nb[0], dst->nb[1],
dst->nb[2], nc, nr, n_t, n_s, stream);
}
#include "common.cuh"
void ggml_cuda_op_ssm_conv(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#include "ssm-scan.cuh"
template <size_t splitD, size_t N>
__global__ void __launch_bounds__(splitD, 2)
ssm_scan_f32(const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
const float * __restrict__ src3, const float * __restrict__ src4, const float * __restrict__ src5,
const int src0_nb1, const int src0_nb2, const int src1_nb0, const int src1_nb1, const int src1_nb2,
const int src1_nb3, const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * __restrict__ dst, const int64_t L) {
GGML_UNUSED(src1_nb0);
GGML_UNUSED(src2_nb0);
const int bidx = blockIdx.x; // split along B
const int bidy = blockIdx.y; // split along D
const int tid = threadIdx.x;
const int wid = tid / 32;
const int wtid = tid % 32;
extern __shared__ float smem[];
const int stride_sA = N + 1;
const int stride_ss0 = N + 1;
float * smem_A = smem;
float * smem_s0 = smem_A + splitD * stride_sA;
const float * s0_block = (const float *) ((const char *) src0 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const float * x_block = (const float *) ((const char *) src1 + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
const float * dt_block = (const float *) ((const char *) src2 + (bidx * src2_nb2) + bidy * splitD * sizeof(float));
const float * A_block = (const float *) ((const char *) src3 + bidy * splitD * src3_nb1);
const float * B_block = (const float *) ((const char *) src4 + (bidx * src4_nb2));
const float * C_block = (const float *) ((const char *) src5 + (bidx * src5_nb2));
float * y_block = (float *) ((char *) dst + (bidx * src1_nb2) + bidy * splitD * sizeof(float));
float * s_block = (float *) ((char *) dst + src1_nb3 + bidx * src0_nb2 + bidy * splitD * src0_nb1);
const int stride_s0 = src0_nb1 / sizeof(float);
const int stride_x = src1_nb1 / sizeof(float);
const int stride_dt = src2_nb1 / sizeof(float);
const int stride_A = src3_nb1 / sizeof(float);
const int stride_B = src4_nb1 / sizeof(float);
const int stride_C = src5_nb1 / sizeof(float);
const int stride_s = stride_s0;
const int stride_y = stride_x;
// can N not be 16? for example 32?
if (N == 16) {
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = A_block[(wid * warpSize + i) * stride_A + wtid];
// todo: bank conflict
// I am always confused with how to use the swizzling method to solve
// bank conflit. Hoping somebody can tell me.
smem_A[(wid * warpSize + i) * stride_sA + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
#pragma unroll
for (size_t i = 0; i < splitD / 4; i += 2) {
float value = s0_block[(wid * warpSize + i) * stride_s0 + wtid];
smem_s0[(wid * warpSize + i) * stride_ss0 + wtid + ((wtid / 16) > 0 ? 1 : 0)] = value;
}
}
__syncthreads();
for (int64_t i = 0; i < L; i++) {
float dt_soft_plus = dt_block[i * stride_dt + tid];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(exp(dt_soft_plus));
}
float x_dt = x_block[i * stride_x + tid] * dt_soft_plus;
float sumf = 0.0f;
#pragma unroll
for (size_t j = 0; j < N; j++) {
float state = (smem_s0[tid * stride_ss0 + j] * expf(dt_soft_plus * smem_A[tid * stride_sA + j])) +
(B_block[i * stride_B + j] * x_dt);
sumf += state * C_block[i * stride_C + j];
if (i == L - 1) {
s_block[tid * stride_s + j] = state;
} else {
smem_s0[tid * stride_ss0 + j] = state;
}
}
__syncthreads();
y_block[i * stride_y + tid] = sumf;
}
}
static void ssm_scan_f32_cuda(const float * src0, const float * src1, const float * src2, const float * src3,
const float * src4, const float * src5, const int src0_nb1, const int src0_nb2,
const int src1_nb0, const int src1_nb1, const int src1_nb2, const int src1_nb3,
const int src2_nb0, const int src2_nb1, const int src2_nb2, const int src3_nb1,
const int src4_nb1, const int src4_nb2, const int src5_nb1, const int src5_nb2,
float * dst, const int64_t N, const int64_t D, const int64_t L, const int64_t B,
cudaStream_t stream) {
const int threads = 128;
// todo: consider D cannot be divided,does this situation exist?
GGML_ASSERT(D % threads == 0);
const dim3 blocks(B, (D + threads - 1) / threads, 1);
const int smem_size = (threads * (N + 1) * 2) * sizeof(float);
if (N == 16) {
ssm_scan_f32<128, 16><<<blocks, threads, smem_size, stream>>>(
src0, src1, src2, src3, src4, src5, src0_nb1, src0_nb2, src1_nb0, src1_nb1, src1_nb2, src1_nb3, src2_nb0,
src2_nb1, src2_nb2, src3_nb1, src4_nb1, src4_nb2, src5_nb1, src5_nb2, dst, L);
} else {
GGML_ABORT("doesn't support N!=16.");
}
}
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const struct ggml_tensor * src0 = dst->src[0]; // s
const struct ggml_tensor * src1 = dst->src[1]; // x
const struct ggml_tensor * src2 = dst->src[2]; // dt
const struct ggml_tensor * src3 = dst->src[3]; // A
const struct ggml_tensor * src4 = dst->src[4]; // B
const struct ggml_tensor * src5 = dst->src[5]; // C
// const int64_t d_state = src0->ne[0];
// const int64_t d_inner = src0->ne[1];
// const int64_t l = src1->ne[1];
// const int64_t b = src0->ne[2];
const int64_t nc = src0->ne[0]; // d_state
const int64_t nr = src0->ne[1]; // d_inner
const int64_t n_t = src1->ne[1]; // number of tokens per sequence
const int64_t n_s = src0->ne[2]; // number of sequences in the batch
GGML_ASSERT(ggml_nelements(src1) + ggml_nelements(src0) == ggml_nelements(dst));
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT(src1->nb[0] == sizeof(float));
GGML_ASSERT(src2->nb[0] == sizeof(float));
GGML_ASSERT(src3->nb[0] == sizeof(float));
GGML_ASSERT(src4->nb[0] == sizeof(float));
GGML_ASSERT(src5->nb[0] == sizeof(float));
// required for the dot product between s and C
GGML_ASSERT(src0->nb[1] == src0->ne[0] * sizeof(float));
// required for per-sequence offsets for states
GGML_ASSERT(src0->nb[2] == src0->ne[0] * src0->ne[1] * sizeof(float));
// required to get correct offset for state destination (i.e. src1->nb[3])
GGML_ASSERT(src1->nb[3] == src1->ne[0] * src1->ne[1] * src1->ne[2] * sizeof(float));
const float * src0_d = (const float *) src0->data;
const float * src1_d = (const float *) src1->data;
const float * src2_d = (const float *) src2->data;
const float * src3_d = (const float *) src3->data;
const float * src4_d = (const float *) src4->data;
const float * src5_d = (const float *) src5->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
ssm_scan_f32_cuda(src0_d, src1_d, src2_d, src3_d, src4_d, src5_d, src0->nb[1], src0->nb[2], src1->nb[0],
src1->nb[1], src1->nb[2], src1->nb[3], src2->nb[0], src2->nb[1], src2->nb[2], src3->nb[1],
src4->nb[1], src4->nb[2], src5->nb[1], src5->nb[2], dst_d, nc, nr, n_t, n_s, stream);
}
#include "common.cuh"
void ggml_cuda_op_ssm_scan(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment