#include "common.h"
#include "vec.h"

namespace {

// [NOTE] TODO list for this kernel:
//   1. tune the value for BLOCK_N
//   2. planning for {batches, num_heads, num_kv_splits}
//      and use actual num_kv_splits for small seq length
//   3. try fast impl of `.tanh()`
//   4. provide amx kernel for index_gemm_kernel_nn when M = 16
//

// Set `size` floats starting at `out` to `val`.
inline void fill_stub(float* __restrict__ out, float val, int64_t size) {
  using Vec = at::vec::Vectorized<float>;
  const Vec data_vec(val);
  at::vec::map<float>([data_vec](Vec out) { return out = data_vec; }, out, out, size);
}

// out[d] = scalar_t(acc[d] * s) for d in [0, size).
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  using fVec = at::vec::Vectorized<float>;
  const fVec s_fvec = fVec(s);
  int64_t d = 0;
  for (; d <= size - bVec::size(); d += bVec::size()) {
    fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
    fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
    bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
    out_bvec.store(out + d);
  }
  for (; d < size; ++d) {
    out[d] = static_cast<scalar_t>(acc[d] * s);
  }
}

// Element-wise copy of `size` values from `src` to `out`.
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
  using bVec = at::vec::Vectorized<scalar_t>;
  int64_t d = 0;
  for (; d <= size - bVec::size(); d += bVec::size()) {
    bVec out_bvec = bVec::loadu(src + d);
    out_bvec.store(out + d);
  }
  for (; d < size; ++d) {
    out[d] = src[d];
  }
}

// GEMM handles query @ key (indexed) x scale
//   A : [M, K]
//   B : [N, K] indexed (row n of B is B[indices[n]])
//   C : [M, N]
//
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt {
  static inline void apply(
      const scalar_t* __restrict__ A,
      const scalar_t* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      float scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    for (int64_t m = 0; m < BLOCK_M; ++m) {
      for (int64_t n = 0; n < BLOCK_N; ++n) {
        float sum = 0.f;
        int64_t b_idx = indices[n];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        for (int64_t k = 0; k < K; ++k) {
          sum += scale * static_cast<float>(A[m * lda + k]) * static_cast<float>(B[b_idx * ldb + k]);
        }
        C[m * ldc + n] = sum;
      }
    }
  }
};

#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::BFloat16* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      float scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N;

    __m512bh va;
    __m512bh vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale = _mm512_set1_ps(scale);

    auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
    Unroll<ROWS * COLS>{}(loadc);

    // main loop: 32 bf16 elements of K per step
    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = (__m512bh)(_mm512_loadu_si512(A + row * lda + k));
      }
      if constexpr (row == 0) {
        if constexpr (col + 1 < COLS) {
          int64_t b_idx_prefetch = indices[col + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + k, _MM_HINT_T0);
        }
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        vb[col] = (__m512bh)(_mm512_loadu_si512(B + b_idx * ldb + k));
      }
      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
    };

    // remainder: fewer than 32 elements left, use zero-masked loads
    auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = (__m512bh)(_mm512_maskz_loadu_epi16(mask, A + row * lda + k));
      }
      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        vb[col] = (__m512bh)(_mm512_maskz_loadu_epi16(mask, B + b_idx * ldb + k));
      }
      vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
    };

    int64_t k = 0;
    for (; k <= K - 32; k += 32) {
      Unroll<ROWS * COLS>{}(compute, k);
    }
    int64_t count = K - k;
    if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      Unroll<ROWS * COLS>{}(compute2, k, mask);
    }

    // each accumulator holds 16 partial sums for one (row, col) dot product
    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif

#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);

// this is used when N isn't multiple of 16,
// N corresponds to `head_size_v` which should be 16x
template <typename scalar_t, typename index_t>
inline void tinygemm_kernel_nn_scalar(
    const float* __restrict__ A,
    const scalar_t* __restrict__ B,
    float* __restrict__ C,
    const index_t* __restrict__ indices,
    const float* __restrict__ scale,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t max_tokens) {
  for (int64_t m = 0; m < M; ++m) {
    for (int64_t n = 0; n < N; ++n) {
      C[m * ldc + n] *= scale[m];
      for (int64_t k = 0; k < K; ++k) {
        int64_t b_idx = indices[k];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        C[m * ldc + n] += A[m * lda + k] * static_cast<float>(B[b_idx * ldb + n]);
      }
    }
  }
}

// GEMM handles v' * scale + attn @ value (indexed)
//   A : [M, K]
//   B : [K, N] indexed (row k of B is B[indices[k]])
//   C : [M, N], scaled per row by scale[m] before accumulation
//
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
  static inline void apply(
      const float* __restrict__ A,
      const scalar_t* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      const float* __restrict__ scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    tinygemm_kernel_nn_scalar(A, B, C, indices, scale, BLOCK_M, BLOCK_N, K, lda, ldb, ldc, max_tokens);
  }
};

#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const float* __restrict__ A,
      const at::BFloat16* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      const float* __restrict__ scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;

    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale;

    // C <- C * scale[row]
    auto loadc = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
      if constexpr (col == 0) {
        vscale = _mm512_set1_ps(scale[row]);
      }
#pragma GCC diagnostic pop
      vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
      vc[i] = _mm512_mul_ps(vc[i], vscale);
    };
    Unroll<ROWS * COLS>{}(loadc);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;

      if constexpr (col == 0) {
        va = _mm512_set1_ps(A[row * lda + k]);
      }
      if constexpr (row == 0) {
        if (k + 1 < K) {
          int64_t b_idx_prefetch = indices[k + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
        }
        int64_t b_idx = indices[k];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");

        // for COLS = 2, 4, 6, 8 use 512 bit load
        // for COLS = 1, 3, 5, 7 use 256 bit load
        if constexpr (COLS % 2 == 0) {
          if constexpr (col % 2 == 0) {
            __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
            vb[col + 0] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
            vb[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
          }
        } else {
          __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
          vb[col] = CVT_BF16_TO_FP32(b16);
        }
      }
      vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
    };
    for (int64_t k = 0; k < K; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif

#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE)                 \
  tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda,                                         \
      B + nb_start,                                               \
      C + mb_start * ldc + nb_start,                              \
      indices,                                                    \
      scale + mb_start,                                           \
      lda,                                                        \
      ldb,                                                        \
      ldc,                                                        \
      K,                                                          \
      max_tokens);

// C[M, N] = scale * A[M, K] @ B[indices[0..N), K]^T, dispatched onto
// fixed-size micro kernels (pattern names read rows-cols-accumulators).
template <typename scalar_t, typename index_t>
void index_gemm_kernel_nt(
    const scalar_t* __restrict__ A,
    const scalar_t* __restrict__ B,
    float* __restrict__ C,
    const index_t* __restrict__ indices,
    float scale,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t max_tokens) {
  // pattern: 1-8-8
  if (M == 1) {
    constexpr int64_t BLOCK_N = 8;
    const int64_t NB = div_up(N, BLOCK_N);
    // single row: the micro kernels only touch row 0, so lda/ldc are never applied
    const int64_t mb_start = 0;

    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (nb_size) {
        case 1:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
          break;
        case 2:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
          break;
        case 3:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
          break;
        case 4:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
          break;
        case 5:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
          break;
        case 6:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
          break;
        case 7:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 7);
          break;
        case 8:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 8);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, 1x", nb_size);
      }
    }
    return;
  }

  // pattern: 4-6-24
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 6;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (mb_size << 4 | nb_size) {
        // mb_size = 1
        case 0x11:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
          break;
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
          break;
        case 0x13:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
          break;
        case 0x14:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
          break;
        case 0x15:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
          break;
        case 0x16:
          LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
          break;
        // mb_size = 2
        case 0x21:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 1);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 2);
          break;
        case 0x23:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 3);
          break;
        case 0x24:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 4);
          break;
        case 0x25:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 5);
          break;
        case 0x26:
          LAUNCH_TINYGEMM_KERNEL_NT(2, 6);
          break;
        // mb_size = 3
        case 0x31:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 1);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 2);
          break;
        case 0x33:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 3);
          break;
        case 0x34:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 4);
          break;
        case 0x35:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 5);
          break;
        case 0x36:
          LAUNCH_TINYGEMM_KERNEL_NT(3, 6);
          break;
        // mb_size = 4
        case 0x41:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 1);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 2);
          break;
        case 0x43:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 3);
          break;
        case 0x44:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 4);
          break;
        case 0x45:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 5);
          break;
        case 0x46:
          LAUNCH_TINYGEMM_KERNEL_NT(4, 6);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}

// C[M, N] = C * scale[m] + A[M, K] @ B[indices[0..K), N]. Falls back to the
// scalar kernel when N is not a multiple of 16.
template <typename scalar_t, typename index_t>
void index_gemm_kernel_nn(
    const float* __restrict__ A,
    const scalar_t* __restrict__ B,
    float* __restrict__ C,
    const index_t* __restrict__ indices,
    const float* __restrict__ scale,
    int64_t M,
    int64_t N,
    int64_t K,
    int64_t lda,
    int64_t ldb,
    int64_t ldc,
    int64_t max_tokens) {
  constexpr int kVecSize = 16;
  if ((N & (kVecSize - 1)) != 0) {
    tinygemm_kernel_nn_scalar(A, B, C, indices, scale, M, N, K, lda, ldb, ldc, max_tokens);
    return;
  }

  // pattern: 1-8-8
  if (M == 1) {
    constexpr int64_t BLOCK_N = 8 * kVecSize;
    const int64_t NB = div_up(N, BLOCK_N);
    // single row: the micro kernels only touch row 0, so lda/ldc are never applied
    const int64_t mb_start = 0;

    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (nb_size >> 4) {
        case 1:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
          break;
        case 2:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
          break;
        case 3:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
          break;
        case 4:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
          break;
        case 5:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
          break;
        case 6:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
          break;
        case 7:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 112);
          break;
        case 8:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 128);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, 1x", nb_size);
      }
    }
    return;
  }

  // pattern: 4-6-24
  constexpr int64_t BLOCK_M = 4;
  constexpr int64_t BLOCK_N = 6 * kVecSize;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);

  for (int64_t mb = 0; mb < MB; ++mb) {
    int64_t mb_start = mb * BLOCK_M;
    int64_t mb_size = std::min(BLOCK_M, M - mb_start);
    for (int64_t nb = 0; nb < NB; ++nb) {
      int64_t nb_start = nb * BLOCK_N;
      int64_t nb_size = std::min(BLOCK_N, N - nb_start);

      switch (mb_size << 4 | nb_size >> 4) {
        // mb_size = 1
        case 0x11:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
          break;
        case 0x12:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
          break;
        case 0x13:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
          break;
        case 0x14:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
          break;
        case 0x15:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
          break;
        case 0x16:
          LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
          break;
        // mb_size = 2
        case 0x21:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
          break;
        case 0x22:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
          break;
        case 0x23:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
          break;
        case 0x24:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
          break;
        case 0x25:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 80);
          break;
        case 0x26:
          LAUNCH_TINYGEMM_KERNEL_NN(2, 96);
          break;
        // mb_size = 3
        case 0x31:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
          break;
        case 0x32:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
          break;
        case 0x33:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
          break;
        case 0x34:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
          break;
        case 0x35:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 80);
          break;
        case 0x36:
          LAUNCH_TINYGEMM_KERNEL_NN(3, 96);
          break;
        // mb_size = 4
        case 0x41:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
          break;
        case 0x42:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
          break;
        case 0x43:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
          break;
        case 0x44:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
          break;
        case 0x45:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 80);
          break;
        case 0x46:
          LAUNCH_TINYGEMM_KERNEL_NN(4, 96);
          break;
        default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
      }
    }
  }
}

// Copy this step's new key/value vectors into the KV cache at row loc[bs],
// for every (bs, head) pair in [batches, num_heads_kv).
// NB: k_buffer/v_buffer (and key/value) may share storage for MLA, so they
// are deliberately not marked __restrict__ here.
template <typename scalar_t>
void update_kv_cache(
    scalar_t* k_buffer,
    scalar_t* v_buffer,
    const scalar_t* key,
    const scalar_t* value,
    const int64_t* __restrict__ loc,
    int64_t batches,
    int64_t num_heads_kv,
    int64_t head_size,
    int64_t head_size_v,
    int64_t k_strideN,
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    int64_t max_total_num_tokens) {
  at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_id{0};
    data_index_init(begin, bs, batches, head_id, num_heads_kv);

    for (int64_t i = begin; i < end; ++i) {
      const int64_t loc_val = loc[bs];
      // guard the cache write the same way token reads are guarded
      TORCH_CHECK(loc_val >= 0 && loc_val < max_total_num_tokens, "loc out of scope!");

      scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH;
      scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH;
      const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH;
      const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH;
      copy_stub(k_buffer_ptr, new_key_ptr, head_size);
      copy_stub(v_buffer_ptr, new_value_ptr, head_size_v);

      // move to the next index
      data_index_step(bs, batches, head_id, num_heads_kv);
    }
  });
}

// One online-softmax step over a row of `n_size` raw scores, in place.
// Applies the optional logit cap, then updates the running max `m_prime` and
// running sum `s_prime`. On return `s` holds exp(s - m_new), and `m_delta`
// holds exp(m_old - m_new), the factor used to rescale the accumulated V'.
inline void online_softmax_update(
    float* s,
    int64_t n_size,
    bool has_logit_cap,
    float logit_cap,
    float rlogit_cap,
    float& m_prime,
    float& s_prime,
    float& m_delta) {
  using Vec = at::vec::Vectorized<float>;

  // TODO: `tanh` from torch uses sleef u10, going to be slow
  if (has_logit_cap) {
    at::vec::map<float>(
        [logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); }, s, s, n_size);
  }

  // m_i: max value of the row
  float m_i = at::vec::reduce_all<float>([](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s, n_size);
  m_i = std::max(m_i, m_prime);

  // m_delta <- exp(m' - m_i)
  m_delta = std::exp(m_prime - m_i);

  // s_delta <- exp(s_i - m_i)
  at::vec::map<float>([m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s, s, n_size);

  // s' <- s' * m_delta + sum(s_delta)
  s_prime *= m_delta;
  s_prime += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s, n_size);

  m_prime = m_i;
}

// Finish one kv split: normalize V' by the softmax sum and store the split's
// log-sum-exp in the trailing slot v_prime[head_size_v]. An empty split (no
// tokens) keeps its zero-filled V' and gets -inf, so the cross-split
// reduction gives it zero weight instead of reading a stale value.
inline void finalize_split(float* v_prime, int64_t head_size_v, bool has_tokens, float m_prime, float s_prime) {
  using Vec = at::vec::Vectorized<float>;
  if (!has_tokens) {
    v_prime[head_size_v] = -std::numeric_limits<float>::infinity();
    return;
  }
  const float s = 1 / s_prime;
  at::vec::map<float>([s](Vec out) { return out * Vec(s); }, v_prime, v_prime, head_size_v);
  v_prime[head_size_v] = m_prime + std::log(s_prime);
}

// Merge per-split partial results into the final output.
// attn_logits is read as [num_rows, num_kv_splits, head_size_v + 1] and
// output as [num_rows, head_size_v]; row i of output gets the softmax-weighted
// combination of the row's splits.
template <typename scalar_t>
void reduce_kv_splits(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    int64_t num_rows,
    int64_t num_kv_splits,
    int64_t head_size_v) {
  using Vec = at::vec::Vectorized<float>;
  const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
  const int64_t l_stride2 = head_size_v + 1;
  constexpr float kNegInf = -std::numeric_limits<float>::infinity();

  at::parallel_for(0, num_rows, 0, [&](int64_t begin, int64_t end) {
    // NB: here we use logits[b][h][0] as acc, since
    // for the first non-empty kv split:
    //   m_delta = std::exp(-inf) = 0
    //   e_logic = std::exp(0) = 1
    //   acc = acc * m_delta + tv * e_logic = tv
    for (int64_t i = begin; i < end; ++i) {
      float* acc = attn_logits + i * l_stride1;

      float s_prime = 0.f;
      float m_prime = kNegInf;

      // update acc with from each kv_split
      for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
        const float* tv = acc + kv_id * l_stride2;
        const float tlogic = tv[head_size_v];
        if (tlogic == kNegInf) {
          // empty split: contributes nothing (its V' slot is zero-filled)
          continue;
        }

        const float m_i = std::max(tlogic, m_prime);
        const float m_delta = std::exp(m_prime - m_i);
        const float e_logic = std::exp(tlogic - m_i);
        if (kv_id != 0) {
          at::vec::map2<float>(
              [m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
              acc,
              acc,
              tv,
              head_size_v);
        }

        s_prime = s_prime * m_delta + e_logic;
        m_prime = m_i;
      }

      copy_stub(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
    }
  });
}

// Decode attention where every query head has its own kv head (MHA).
template <typename scalar_t, typename index_t>
void decode_attention_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
    scalar_t* __restrict__ k_buffer,
    scalar_t* __restrict__ v_buffer,
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    const int64_t* __restrict__ loc,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
    int64_t batches,
    int64_t num_heads,
    int64_t head_size,
    int64_t head_size_v,
    int64_t num_kv_splits,
    int64_t k_strideN,
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    float scaling,
    float logit_cap,
    int64_t max_num_reqs,
    int64_t max_context_len,
    int64_t max_total_num_tokens) {
  update_kv_cache(
      k_buffer,
      v_buffer,
      key,
      value,
      loc,
      batches,
      num_heads,
      head_size,
      head_size_v,
      k_strideN,
      k_strideH,
      v_strideN,
      v_strideH,
      nk_strideN,
      nk_strideH,
      nv_strideN,
      nv_strideH,
      max_total_num_tokens);

  // block length for k_buffer and v_buffer
  constexpr int64_t BLOCK_N = 256;

  // strides
  const int64_t q_strideM = num_heads * head_size;
  const int64_t q_strideH = head_size;

  const bool has_logit_cap = logit_cap > 0;
  const float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;

  // parallel on [batches, num_heads, num_kv_splits]
  at::parallel_for(0, batches * num_heads * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_id{0}, kv_id{0};
    data_index_init(begin, bs, batches, head_id, num_heads, kv_id, num_kv_splits);

    // scores for one block of tokens, overwritten in place by exp(s - m)
    alignas(64) float s_i[BLOCK_N];

    for (int64_t i = begin; i < end; ++i) {
      // get query
      const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + head_id * q_strideH;

      // get key/value
      const int64_t seq_len_kv = seq_lens[bs];
      const int64_t req_pool_id = req_pool_indices[bs];
      TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
      TORCH_CHECK(req_pool_id >= 0 && req_pool_id < max_num_reqs, "req_pool_id out of scope!");
      const index_t* token_ids = req_to_token + req_pool_id * max_context_len;

      const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
      const int64_t kv_start = kv_id * SPLIT_SIZE;
      const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);

      float m_prime = -std::numeric_limits<float>::infinity();
      float s_prime = 0.f;

      // get v_prime, and init to zero
      float* __restrict__ v_prime = attn_logits + i * (head_size_v + 1);
      fill_stub(v_prime, 0.f, head_size_v);

      // loop over K and V sequence with BLOCK_N
      for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
        const int64_t n_size = std::min(BLOCK_N, kv_end - n);

        // calculate s_i <- scale * Q @ K
        index_gemm_kernel_nt<scalar_t, index_t>(
            /* A   */ q_ptr,
            /* B   */ k_buffer + head_id * k_strideH,
            /* C   */ s_i,
            /* ind */ token_ids + n,
            /* scl */ scaling,
            /* M   */ 1,
            /* N   */ n_size,
            /* K   */ head_size,
            /* lda */ 1,
            /* ldb */ k_strideN,
            /* ldc */ 1,
            /* mtt */ max_total_num_tokens);

        float m_delta = 0.f;
        online_softmax_update(s_i, n_size, has_logit_cap, logit_cap, rlogit_cap, m_prime, s_prime, m_delta);

        // calculate V' <- s_delta @ V + V' * m_delta
        index_gemm_kernel_nn<scalar_t, index_t>(
            /* A   */ s_i,
            /* B   */ v_buffer + head_id * v_strideH,
            /* C   */ v_prime,
            /* ind */ token_ids + n,
            /* scl */ &m_delta,
            /* M   */ 1,
            /* N   */ head_size_v,
            /* K   */ n_size,
            /* lda */ 1,
            /* ldb */ v_strideN,
            /* ldc */ 1,
            /* mtt */ max_total_num_tokens);
      }  // loop with KV blocks

      finalize_split(v_prime, head_size_v, kv_end > kv_start, m_prime, s_prime);

      // move to the next index
      data_index_step(bs, batches, head_id, num_heads, kv_id, num_kv_splits);
    }
  });

  // parallel on [batches, num_heads]
  reduce_kv_splits(output, attn_logits, batches * num_heads, num_kv_splits, head_size_v);
}

// Decode attention where groups of num_heads / num_heads_kv query heads share
// one kv head (GQA/MQA/MLA).
template <typename scalar_t, typename index_t>
void decode_attention_grouped_kernel_impl(
    scalar_t* __restrict__ output,
    float* __restrict__ attn_logits,
    const scalar_t* __restrict__ query,
    scalar_t* __restrict__ k_buffer,
    scalar_t* __restrict__ v_buffer,
    const scalar_t* __restrict__ key,
    const scalar_t* __restrict__ value,
    const int64_t* __restrict__ loc,
    const index_t* __restrict__ req_to_token,
    const int64_t* __restrict__ req_pool_indices,
    const int64_t* __restrict__ seq_lens,
    int64_t batches,
    int64_t num_heads,
    int64_t num_heads_kv,
    int64_t head_size,
    int64_t head_size_v,
    int64_t num_kv_splits,
    int64_t k_strideN,
    int64_t k_strideH,
    int64_t v_strideN,
    int64_t v_strideH,
    int64_t nk_strideN,
    int64_t nk_strideH,
    int64_t nv_strideN,
    int64_t nv_strideH,
    float scaling,
    float logit_cap,
    int64_t max_num_reqs,
    int64_t max_context_len,
    int64_t max_total_num_tokens) {
  update_kv_cache(
      k_buffer,
      v_buffer,
      key,
      value,
      loc,
      batches,
      num_heads_kv,
      head_size,
      head_size_v,
      k_strideN,
      k_strideH,
      v_strideN,
      v_strideH,
      nk_strideN,
      nk_strideH,
      nv_strideN,
      nv_strideH,
      max_total_num_tokens);

  // block length for k_buffer and v_buffer
  constexpr int64_t BLOCK_N = 256;

  // block length for heads
  // we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
  // use smaller BLOCK_H when batches is small to utilize all cores;
  // clamp to >= 1 so div_up below stays well-defined when batches == 0
  constexpr int64_t kBLOCK_H = 16;
  const int64_t BLOCK_H = std::max<int64_t>(1, std::min(4 * batches, kBLOCK_H));

  // strides
  const int64_t q_strideM = num_heads * head_size;
  const int64_t q_strideH = head_size;
  const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
  const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
  const int64_t l_stride2 = head_size_v + 1;

  const bool has_logit_cap = logit_cap > 0;
  const float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;

  // partition the heads into blocks for parallel
  const int64_t num_groups = num_heads / num_heads_kv;
  const int64_t num_blocks = div_up(num_groups, BLOCK_H);

  // parallel on [batches, num_heads_kv, num_blocks, num_kv_splits]
  at::parallel_for(0, batches * num_heads_kv * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
    int64_t bs{0}, head_kv_id{0}, block_id{0}, kv_id{0};
    data_index_init(begin, bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);

    // sized by the compile-time upper bound (BLOCK_H <= kBLOCK_H) to avoid VLAs;
    // row h of s_i (stride BLOCK_N) holds the scores of head h_start + h
    alignas(64) float s_i[kBLOCK_H * BLOCK_N];
    alignas(64) float s_prime[kBLOCK_H];
    alignas(64) float m_prime[kBLOCK_H];
    alignas(64) float m_delta[kBLOCK_H];

    for (int64_t i = begin; i < end; ++i) {
      const int64_t h_start = head_kv_id * num_groups + block_id * BLOCK_H;
      const int64_t h_end = head_kv_id * num_groups + std::min(block_id * BLOCK_H + BLOCK_H, num_groups);
      const int64_t h_size = h_end - h_start;

      // get query
      const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;

      const int64_t seq_len_kv = seq_lens[bs];
      const int64_t req_pool_id = req_pool_indices[bs];
      TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
      TORCH_CHECK(req_pool_id >= 0 && req_pool_id < max_num_reqs, "req_pool_id out of scope!");
      const index_t* token_ids = req_to_token + req_pool_id * max_context_len;

      const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
      const int64_t kv_start = kv_id * SPLIT_SIZE;
      const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);

      fill_stub(s_prime, 0.f, BLOCK_H);
      fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);

      // get v_prime, and init to zero
      float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
      for (int64_t h = 0; h < h_size; ++h) {
        fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
      }

      // loop over K and V sequence with BLOCK_N
      for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
        const int64_t n_size = std::min(BLOCK_N, kv_end - n);

        // calculate Q @ K
        index_gemm_kernel_nt<scalar_t, index_t>(
            /* A   */ q_ptr,
            /* B   */ k_buffer + head_kv_id * k_strideH,
            /* C   */ s_i,
            /* ind */ token_ids + n,
            /* scl */ scaling,
            /* M   */ h_size,
            /* N   */ n_size,
            /* K   */ head_size,
            /* lda */ q_strideH,
            /* ldb */ k_strideN,
            /* ldc */ BLOCK_N,
            /* mtt */ max_total_num_tokens);

        // update the scaling coefficients (logit cap is applied to every head row)
        for (int64_t h = 0; h < h_size; ++h) {
          online_softmax_update(
              s_i + h * BLOCK_N, n_size, has_logit_cap, logit_cap, rlogit_cap, m_prime[h], s_prime[h], m_delta[h]);
        }

        // calculate V' <- s_delta @ V + V' * m_delta
        index_gemm_kernel_nn<scalar_t, index_t>(
            /* A   */ s_i,
            /* B   */ v_buffer + head_kv_id * v_strideH,
            /* C   */ v_prime,
            /* ind */ token_ids + n,
            /* scl */ m_delta,
            /* M   */ h_size,
            /* N   */ head_size_v,
            /* K   */ n_size,
            /* lda */ BLOCK_N,
            /* ldb */ v_strideN,
            /* ldc */ l_stride1,
            /* mtt */ max_total_num_tokens);
      }  // loop with KV blocks

      for (int64_t h = 0; h < h_size; ++h) {
        finalize_split(v_prime + h * l_stride1, head_size_v, kv_end > kv_start, m_prime[h], s_prime[h]);
      }

      // move to the next index
      data_index_step(bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
    }
  });

  // parallel on [batches, num_heads]
  reduce_kv_splits(output, attn_logits, batches * num_heads, num_kv_splits, head_size_v);
}

}  // anonymous namespace

// query: [num_seqs, num_heads, head_size]
// output: [num_seqs, num_heads, head_size_v]
// k_buffer: [max_total_num_tokens, num_heads_kv, head_size]
// v_buffer: [max_total_num_tokens, num_heads_kv, head_size_v]
// key: [num_seqs, num_heads_kv, head_size]     (new key, written at loc)
// value: [num_seqs, num_heads_kv, head_size_v] (new value, written at loc)
// loc: [num_seqs] int64
// attn_logits: [num_seqs, num_heads, num_kv_splits, head_size_v + 1] float
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
//
// NOTE(review): `output` and `attn_logits` are indexed as if contiguous but
// are not checked for contiguity here; callers are assumed to pass dense
// tensors.
void decode_attention_cpu(
    at::Tensor& query,
    at::Tensor& k_buffer,
    at::Tensor& v_buffer,
    at::Tensor& output,
    at::Tensor& key,
    at::Tensor& value,
    at::Tensor& loc,
    at::Tensor& attn_logits,
    at::Tensor& req_to_token,
    at::Tensor& req_pool_indices,
    at::Tensor& seq_lens,
    double sm_scale,
    double logit_cap) {
  RECORD_FUNCTION(
      "sgl-kernel::decode_attention_cpu",
      std::vector<c10::IValue>(
          {query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));

  CHECK_INPUT(query);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
  // for MLA, key and value shares the same storage and value could be non-contiguous
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
  CHECK_LAST_DIM_CONTIGUOUS_INPUT(value);
  CHECK_DIM(3, query);
  CHECK_DIM(3, k_buffer);
  CHECK_DIM(3, v_buffer);
  CHECK_DIM(3, key);
  CHECK_DIM(3, value);
  CHECK_DIM(1, loc);
  CHECK_DIM(4, attn_logits);

  const int64_t num_seqs = seq_lens.size(0);
  const int64_t max_num_reqs = req_to_token.size(0);
  const int64_t max_context_len = req_to_token.size(1);
  const int64_t max_total_num_tokens = k_buffer.size(0);

  const int64_t num_heads = query.size(1);
  const int64_t num_heads_kv = k_buffer.size(1);
  const int64_t head_size = query.size(2);
  const int64_t head_size_v = v_buffer.size(2);

  const int64_t num_kv_splits = attn_logits.size(2);

  CHECK_EQ(loc.numel(), num_seqs);
  CHECK_EQ(attn_logits.size(0), num_seqs);
  CHECK_EQ(attn_logits.size(1), num_heads);
  CHECK_EQ(attn_logits.size(3), head_size_v + 1);
  CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
  // token indices are bounds-checked against k_buffer.size(0) but also used for v_buffer
  CHECK_EQ(v_buffer.size(0), max_total_num_tokens);
  TORCH_CHECK(
      num_heads_kv > 0 && num_heads % num_heads_kv == 0,
      "decode: expect num_heads (",
      num_heads,
      ") to be a multiple of num_heads_kv (",
      num_heads_kv,
      ")");

  // strides for k_buffer and v_buffer
  const int64_t k_strideN = k_buffer.stride(0);
  const int64_t k_strideH = k_buffer.stride(1);
  const int64_t v_strideN = v_buffer.stride(0);
  const int64_t v_strideH = v_buffer.stride(1);
  // strides for new key and value
  const int64_t nk_strideN = key.stride(0);
  const int64_t nk_strideH = key.stride(1);
  const int64_t nv_strideN = value.stride(0);
  const int64_t nv_strideH = value.stride(1);

  // check index data types
  const auto index_dtype = req_to_token.scalar_type();
  TORCH_CHECK(
      index_dtype == at::kInt || index_dtype == at::kLong,
      "decode: expect req_to_token to be int32 or int64, got ",
      index_dtype);
  TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "decode: expect seq_lens to be int64, got ", seq_lens.scalar_type());
  TORCH_CHECK(
      req_pool_indices.scalar_type() == at::kLong,
      "decode: expect req_pool_indices to be int64, got ",
      req_pool_indices.scalar_type());
  TORCH_CHECK(loc.scalar_type() == at::kLong, "decode: expect loc to be int64, got ", loc.scalar_type());

  AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
    AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
      if (num_heads == num_heads_kv) {
        // MHA
        decode_attention_kernel_impl<scalar_t, index_t>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
            k_buffer.data_ptr<scalar_t>(),
            v_buffer.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            value.data_ptr<scalar_t>(),
            loc.data_ptr<int64_t>(),
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
            num_seqs,
            num_heads,
            head_size,
            head_size_v,
            num_kv_splits,
            k_strideN,
            k_strideH,
            v_strideN,
            v_strideH,
            nk_strideN,
            nk_strideH,
            nv_strideN,
            nv_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,
            max_context_len,
            max_total_num_tokens);
      } else {
        // GQA/MQA/MLA
        decode_attention_grouped_kernel_impl<scalar_t, index_t>(
            output.data_ptr<scalar_t>(),
            attn_logits.data_ptr<float>(),
            query.data_ptr<scalar_t>(),
            k_buffer.data_ptr<scalar_t>(),
            v_buffer.data_ptr<scalar_t>(),
            key.data_ptr<scalar_t>(),
            value.data_ptr<scalar_t>(),
            loc.data_ptr<int64_t>(),
            req_to_token.data_ptr<index_t>(),
            req_pool_indices.data_ptr<int64_t>(),
            seq_lens.data_ptr<int64_t>(),
            num_seqs,
            num_heads,
            num_heads_kv,
            head_size,
            head_size_v,
            num_kv_splits,
            k_strideN,
            k_strideH,
            v_strideN,
            v_strideH,
            nk_strideN,
            nk_strideH,
            nv_strideN,
            nv_strideH,
            sm_scale,
            logit_cap,
            max_num_reqs,
            max_context_len,
            max_total_num_tokens);
      }
    });
  });
}