Unverified commit fcde67b0 authored by YanbingJiang, committed by GitHub

CPU: map changes from developing branch in sgl-kernel (#6833)


Co-authored-by: mingfeima <mingfei.ma@intel.com>
parent 81372f3b
#include "common.h" #include "common.h"
#include "gemm.h"
#include "vec.h" #include "vec.h"
namespace { namespace {
...@@ -11,19 +12,144 @@ namespace { ...@@ -11,19 +12,144 @@ namespace {
// 4. provide amx kernel for index_gemm_kernel_nn when M = 16 // 4. provide amx kernel for index_gemm_kernel_nn when M = 16
// //
inline void fill_stub(float* __restrict__ out, float val, int64_t size) { #if defined(CPU_CAPABILITY_AVX512)
using Vec = at::vec::Vectorized<float>; // key: from [N, 32] to [32/2, N, 2]
const Vec data_vec(val); // val: from [N, 32] to [N/2, 32, 2]
at::vec::map<float>([data_vec](Vec out) { return out = data_vec; }, out, out, size); template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
scalar_t* __restrict__ dst0,
scalar_t* __restrict__ dst1,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int ld_src,
int ld_dst0,
int ld_dst1,
bool convert_v) {
__m512i vinputs[16];
int n = 0;
for (; n < N; ++n) {
vinputs[n] = _mm512_loadu_si512(src + ind[n] * ld_src);
}
// padding with zero to avoid uninitialized vectors
for (; n < 16; ++n) {
vinputs[n] = _mm512_set1_epi32(0);
}
// pack value, skip 64 elems for deepseek
// handle 2 vectors at a time from [2, 32] to [32, 2]
if (convert_v) {
for (int n = 0; n < 16; n += 2) {
__m512i d0, d1;
std::tie(d0, d1) = transpose_2x32_16bit(vinputs[n], vinputs[n + 1]);
_mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2, d0);
_mm512_storeu_si512(dst1 + (n >> 1) * ld_dst1 * 2 + 32, d1);
}
}
// pack key
transpose_16x16_32bit(vinputs);
const __mmask16 vmask = (1 << N) - 1;
for (int k = 0; k < 16; ++k) {
_mm512_mask_storeu_epi32(dst0 + k * ld_dst0 * 2, vmask, vinputs[k]);
}
}
#endif
// [NOTE]: MLA vnni format conversion
//
// here we apply same strategy as `FlashMLA`:
// each kv_cache is loaded once and packed twice (L2 cache hit)
//
// * for key: from [N, K/2, 2] to [K/2, N, 2]
// * for value: from [N/2, 2, Kv] to [N/2, Kv, 2]
//
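//
// Editor's note (illustrative, not part of this commit): with the DeepSeek-style MLA
// cache detected below in decode_attention_cpu (head_size = head_size_v + 64, e.g.
// 576 = 512 + 64), the two packings work out to
//   key:   src[ind[n]][2k + d]  -> dst0[k][n][d]   for k < head_size / 2
//   value: src[ind[2m + d]][k]  -> dst1[m][k][d]   for k < head_size_v only,
// so the trailing 64 rope dimensions are packed for the Q@K GEMM but skipped for the
// s_delta@V GEMM (convert_v is false for those 32-wide column blocks).
//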
template <typename scalar_t, typename index_t>
void pack_vnni(
scalar_t* __restrict__ dst0,
scalar_t* __restrict__ dst1,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int K,
int Kv,
int ld_src,
int ld_dst0,
int ld_dst1) {
#if defined(CPU_CAPABILITY_AVX512)
const int NB = div_up(N, 16);
const int KB = K / 32; // no remainder
const int KBv = Kv / 32; // no remainder
for (int nb = 0; nb < NB; ++nb) {
for (int kb = 0; kb < KB; ++kb) {
// handle 16x512bits each block
int nb_size = std::min(N - nb * 16, 16);
pack_vnni_Nx32<scalar_t, index_t>(
/* dst0 */ dst0 + ((kb * 32) >> 1) * ld_dst0 * 2 + nb * 16 * 2,
/* dst1 */ dst1 + ((nb * 16) >> 1) * ld_dst1 * 2 + kb * 32 * 2,
/* src */ src + kb * 32,
/* ind */ ind + nb * 16,
/* N */ nb_size,
/* ld_src */ ld_src,
/* ld_dst0 */ ld_dst0,
/* ld_dst1 */ ld_dst1,
/* cvt_v */ kb < KBv);
}
}
#else
for (int n = 0; n < N; ++n) {
index_t index = ind[n];
for (int k = 0; k < K / 2; ++k) {
for (int d = 0; d < 2; ++d) {
dst0[k * ld_dst0 * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
}
}
}
// from [N/2, 2, K] to [N/2, K, 2]
for (int n = 0; n < (N >> 1) * 2; n += 2) {
index_t index0 = ind[n + 0];
index_t index1 = ind[n + 1];
for (int k = 0; k < Kv; ++k) {
dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index0 * ld_src + k];
dst1[(n >> 1) * ld_dst1 * 2 + k * 2 + 1] = src[index1 * ld_src + k];
}
}
if (N % 2 != 0) {
index_t index = ind[N - 1];
for (int k = 0; k < Kv; ++k) {
dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 0] = src[index * ld_src + k];
dst1[(N >> 1) * ld_dst1 * 2 + k * 2 + 1] = 0;
}
}
#endif
}
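// Editor's note (illustrative, not part of this commit): the MLA decode kernel below
// calls this with ld_dst0 = BLOCK_N and ld_dst1 = head_size_v, so Btmp0 can be handed
// to cpublas::brgemm directly as the VNNI B tile of the Q@K product (ldb = BLOCK_N)
// and Btmp1 as the B tile of the s_delta@V product (ldb = head_size_v). A hypothetical
// call shape, mirroring that call site (names here are illustrative only):
//   pack_vnni<scalar_t, index_t>(Btmp0, Btmp1, k_buffer, token_indices,
//                                /* N */ n_size, /* K */ head_size, /* Kv */ head_size_v,
//                                /* ld_src */ k_strideN, /* ld_dst0 */ BLOCK_N,
//                                /* ld_dst1 */ head_size_v);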
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = Vec::size();
const Vec data_vec = Vec(static_cast<scalar_t>(val));
int64_t d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
data_vec.store(out + d);
}
if (size - d > 0) {
data_vec.store(out + d, size - d);
}
}

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_fvec = fVec(s);
int64_t d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);

@@ -37,8 +163,10 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc,

template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = bVec::size();
int64_t d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
bVec out_bvec = bVec::loadu(src + d);
out_bvec.store(out + d);
}

@@ -47,6 +175,26 @@ inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ s
}
}
template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
static_assert(BLOCK_N % 32 == 0);
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int COLS = BLOCK_N / 16;
auto store = [&](auto i) {
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
fVec a_fvec0 = fVec::loadu(input + col * 16);
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + col * 16);
}
};
Unroll<COLS>{}(store);
}
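// Editor's note (illustrative, not part of this commit): for the BLOCK_N = 256 used by
// the decode kernels below, COLS is 16, so the unrolled store performs 8 full 512-bit
// bf16/fp16 stores per row (32 elements each); odd columns are folded into the
// preceding even column by convert_from_float_ext.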
// GEMM handles query @ key (indexed) x scale
// A : [M, K]
// B : [N, K] indexed

@@ -619,24 +767,17 @@ void index_gemm_kernel_nn(
}
}
template <typename scalar_t>
void decode_set_kv_buffer(
scalar_t* __restrict__ k_buffer,
scalar_t* __restrict__ v_buffer,
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
const int64_t* __restrict__ loc,
int64_t batches,
int64_t num_heads_kv,
int64_t head_size,
int64_t head_size_v,
int64_t k_strideN,
int64_t k_strideH,
int64_t v_strideN,

@@ -645,33 +786,104 @@ void decode_attention_kernel_impl(
int64_t nk_strideH,
int64_t nv_strideN,
int64_t nv_strideH,
bool is_mla) {
at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_kv_id{0};
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
for (int64_t i = begin; i < end; i++) {
int64_t loc_val = loc[bs];
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
if (!is_mla) {
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
}
// move to the next index
data_index_step(bs, batches, head_kv_id, num_heads_kv);
}
});
}
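// Editor's sketch (not part of this commit): a plain, single-threaded reference of the
// scatter above, useful as a test oracle. Strides and names follow the parameters of
// decode_set_kv_buffer; the MLA case skips the value copy because k_buffer and
// v_buffer alias the same tensor.
//
//   for (int64_t b = 0; b < batches; ++b)
//     for (int64_t h = 0; h < num_heads_kv; ++h) {
//       std::copy_n(key + b * nk_strideN + h * nk_strideH, head_size,
//                   k_buffer + loc[b] * k_strideN + h * k_strideH);
//       if (!is_mla)
//         std::copy_n(value + b * nv_strideN + h * nv_strideH, head_size_v,
//                     v_buffer + loc[b] * v_strideN + h * v_strideH);
//     }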
template <typename scalar_t>
void decode_accumulate_kv_splits(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
int64_t batches,
int64_t num_heads,
int64_t head_size_v,
int64_t num_kv_splits,
int64_t l_stride1,
int64_t l_stride2) {
using Vec = at::vec::Vectorized<float>;
// parallel on [batches, num_heads]
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
// NB: here we use logits[b][h][0] as acc, since
// for the first kv split (kv_id == 0):
// m_delta = std::exp(-inf) = 0
// e_logic = std::exp(0) = 1
// acc = acc * m_delta + tv * e_logic = tv
for (int64_t i = begin; i < end; ++i) {
float* __restrict__ acc = attn_logits + i * l_stride1;
float s_prime = 0.f;
float m_prime = -std::numeric_limits<scalar_t>::infinity();
// update acc with the partial results from each kv_split
for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
float* __restrict__ tv = acc + kv_id * l_stride2;
const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
float m_i = std::max(tlogic, m_prime);
float m_delta = std::exp(m_prime - m_i);
float e_logic = std::exp(tlogic - m_i);
if (kv_id != 0) {
at::vec::map2<float>(
[m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
acc,
acc,
tv,
head_size_v);
}
s_prime = s_prime * m_delta + e_logic;
m_prime = m_i;
}
copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
}
});
}
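// Editor's worked example (not part of this commit): merging two splits whose stored
// logits are tlogic_0 = 1 and tlogic_1 = 3 (each equals m_k + log(s_k) written by the
// per-split kernels) proceeds as
//   kv_id = 0: m_i = 1, m_delta = exp(-inf) = 0, e_logic = 1  -> acc = tv_0, s' = 1
//   kv_id = 1: m_i = 3, m_delta = exp(1 - 3),    e_logic = 1  -> acc = acc * m_delta + tv_1,
//              s' = s' * m_delta + 1
// so the final 1 / s' scaling in copy_stub reproduces the softmax over the full
// sequence, even though each split only saw its own slice of the KV cache.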
template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
int64_t batches,
int64_t num_heads,
int64_t head_size,
int64_t head_size_v,
int64_t num_kv_splits,
int64_t k_strideN,
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
using Vec = at::vec::Vectorized<float>;
// strides
const int64_t q_strideM = num_heads * head_size;

@@ -785,55 +997,209 @@ void decode_attention_kernel_impl(
}
});

decode_accumulate_kv_splits(
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // MHA

template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_mla_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
scalar_t* __restrict__ buffer,
int64_t batches,
int64_t num_heads,
int64_t head_size,
int64_t head_size_v,
int64_t num_kv_splits,
int64_t k_strideN,
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens,
int64_t buffer_size_per_thread) {
using Vec = at::vec::Vectorized<float>;
// block length for heads
const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
// strides
const int64_t q_strideM = num_heads * head_size;
const int64_t q_strideH = head_size;
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
const int64_t l_stride2 = head_size_v + 1;
TORCH_CHECK(logit_cap == 0.f, "decode MLA: expect no logit_cap.");
// partition the heads into blocks for parallel
const int64_t num_blocks = div_up(num_heads, BLOCK_H);
// parallel on [batches, num_blocks, num_kv_splits]
at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, block_id{0}, kv_id{0};
data_index_init(begin, bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp0 = buffer + tid * buffer_size_per_thread;
scalar_t* __restrict__ Btmp1 = Btmp0 + BLOCK_N * head_size;
// init Btmp1 just once for each thread to prevent NaN
// Btmp0 is not needed as it computes full K every single time
fill_stub(Btmp1, 0.f, BLOCK_N * head_size_v);
alignas(64) float s_i[BLOCK_H * BLOCK_N];
float* __restrict__ s_delta = s_i;
alignas(64) scalar_t s_delta2[BLOCK_H * BLOCK_N];
alignas(64) float s_prime[BLOCK_H];
alignas(64) float m_prime[BLOCK_H];
alignas(64) float m_delta[BLOCK_H];
for (int64_t i = begin; i < end; ++i) {
const int64_t h_start = block_id * BLOCK_H;
const int64_t h_end = std::min(block_id * BLOCK_H + BLOCK_H, num_heads);
const int64_t h_size = h_end - h_start;

// get query
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;

int64_t seq_len_kv = seq_lens[bs];
int64_t req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of scope!");

const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
const int64_t kv_start = kv_id * SPLIT_SIZE;
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);

fill_stub(s_prime, 0.f, BLOCK_H);
fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);

// get v_prime, and init to zero
float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
for (int64_t h = 0; h < h_size; ++h) {
fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
}
// loop over K and V sequence with BLOCK_N
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
int64_t n_size = std::min(BLOCK_N, kv_end - n);
const int64_t padded_n_size = div_up(int(n_size), TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst0 */ Btmp0,
/* dst1 */ Btmp1,
/* src */ k_buffer + /* head_kv_id */ 0 * k_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* N */ n_size,
/* K */ head_size,
/* Kv */ head_size_v,
/* ld_src */ k_strideN,
/* ld_dst0 */ BLOCK_N,
/* ld_dst1 */ head_size_v);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ h_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideH,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp0,
/* C */ s_i);
const Vec scale_vec = Vec(scaling);
for (int64_t h = 0; h < h_size; ++h) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[h]);
// m_delta <- exp(m' - m_i)
m_delta[h] = std::exp(m_prime[h] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[h] *= m_delta[h];
s_prime[h] += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);
m_prime[h] = m_i;
// v' <- v' * m_delta
float scale_m = m_delta[h];
at::vec::map<float>(
[scale_m](Vec x) { return x * Vec(scale_m); },
v_prime + h * l_stride1,
v_prime + h * l_stride1,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + h * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + h * BLOCK_N, s_delta + h * BLOCK_N);
}
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ h_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ l_stride1,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp1,
/* C */ v_prime);
} // loop with KV blocks
// only update v' when kv_split_size > 0
if (kv_end > kv_start) {
for (int64_t h = 0; h < h_size; ++h) {
float s = 1 / s_prime[h];
at::vec::map<float>(
[s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
(v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
}
}

// move to the next index
data_index_step(bs, batches, block_id, num_blocks, kv_id, num_kv_splits);
}
at::native::cpublas::brgemm_release();
});

decode_accumulate_kv_splits(
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // MLA
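// Editor's note (illustrative, not part of this commit): each (batch, head, kv_split)
// slot of attn_logits holds head_size_v partial output values followed by one extra
// float at index head_size_v, written above as m_prime + log(s_prime). That extra slot
// is the per-split log-sum-exp that decode_accumulate_kv_splits reads back as `tlogic`
// when it merges the splits.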
template <typename scalar_t, typename index_t, int64_t BLOCK_N>
void decode_attention_grouped_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,

@@ -847,37 +1213,13 @@ void decode_attention_grouped_kernel_impl(
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
using Vec = at::vec::Vectorized<float>;

// block length for heads
// we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
// use smaller BLOCK_H when batches is small to utilize all cores

@@ -960,7 +1302,7 @@ void decode_attention_grouped_kernel_impl(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i,
s_i,
BLOCK_H * BLOCK_N);
}

// update the scaling coefficients

@@ -1015,40 +1357,9 @@ void decode_attention_grouped_kernel_impl(
}
});

decode_accumulate_kv_splits(
output, attn_logits, batches, num_heads, head_size_v, num_kv_splits, l_stride1, l_stride2);
} // GQA/MQA

} // anonymous namespace
@@ -1134,19 +1445,50 @@ void decode_attention_cpu(
"decode: expect req_pool_indices to be int64, got ",
req_pool_indices.scalar_type());
// check if we have MLA here
void* k_buffer_data = k_buffer.data_ptr();
void* v_buffer_data = v_buffer.data_ptr();
const bool is_mla = (k_buffer_data == v_buffer_data) && (num_heads_kv == 1) && (head_size == head_size_v + 64);
// block length for k_buffer and v_buffer
constexpr int BLOCK_N = 256;
// buffer for packing k_cache and v_cache
int num_threads = at::get_num_threads();
int64_t size_per_thread = is_mla ? BLOCK_N * head_size + BLOCK_N * head_size_v : 0;
auto buffer = at::empty({num_threads, size_per_thread}, k_buffer.options());
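// Editor's note (illustrative, not part of this commit): for the MLA path each thread
// gets one packing scratch of BLOCK_N * head_size elements for the key tile (Btmp0)
// plus BLOCK_N * head_size_v elements for the value tile (Btmp1); for example, with
// BLOCK_N = 256, head_size = 576 and head_size_v = 512 in bf16 that is
// 256 * (576 + 512) * 2 bytes, roughly 544 KiB per thread.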
AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
// update the kv buffer
decode_set_kv_buffer(
(scalar_t*)k_buffer_data,
(scalar_t*)v_buffer_data,
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
loc.data_ptr<int64_t>(),
num_seqs,
num_heads_kv,
head_size,
head_size_v,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
nk_strideN,
nk_strideH,
nv_strideN,
nv_strideH,
is_mla);
if (num_heads == num_heads_kv) {
// MHA
decode_attention_kernel_impl<scalar_t, index_t, BLOCK_N>(
output.data_ptr<scalar_t>(),
attn_logits.data_ptr<float>(),
query.data_ptr<scalar_t>(),
(const scalar_t*)k_buffer_data,
(const scalar_t*)v_buffer_data,
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),

@@ -1159,26 +1501,46 @@ void decode_attention_cpu(
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens);
} else if (is_mla) {
// MLA
decode_attention_mla_kernel_impl<scalar_t, index_t, BLOCK_N>(
output.data_ptr<scalar_t>(),
attn_logits.data_ptr<float>(),
query.data_ptr<scalar_t>(),
(const scalar_t*)k_buffer_data,
(const scalar_t*)v_buffer_data,
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
buffer.data_ptr<scalar_t>(),
num_seqs,
num_heads,
head_size,
head_size_v,
num_kv_splits,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens,
size_per_thread);
} else {
// GQA/MQA
decode_attention_grouped_kernel_impl<scalar_t, index_t, BLOCK_N>(
output.data_ptr<scalar_t>(),
attn_logits.data_ptr<float>(),
query.data_ptr<scalar_t>(),
(const scalar_t*)k_buffer_data,
(const scalar_t*)v_buffer_data,
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),

@@ -1192,10 +1554,6 @@ void decode_attention_cpu(
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
......
@@ -10,11 +10,72 @@ namespace {
// 3. computes attention for prefix and extend separately
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
//
template <typename index_t>
inline index_t get_index(index_t* ind, int i) {
return (ind == nullptr) ? (index_t)i : ind[i];
}
#if defined(CPU_CAPABILITY_AVX512)
// key: from [N, 32] to [32/2, N, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Nx32(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int ld_src,
int ld_dst) {
__m512i vinputs[16];
int n = 0;
for (; n < N; ++n) {
index_t index = get_index(ind, n);
vinputs[n] = _mm512_loadu_si512(src + index * ld_src);
}
// padding with zero to avoid uninitialized vectors
for (; n < 16; ++n) {
vinputs[n] = _mm512_set1_epi32(0);
}
// pack key
transpose_16x16_32bit(vinputs);
const __mmask16 vmask = (1 << N) - 1;
for (int k = 0; k < 16; ++k) {
_mm512_mask_storeu_epi32(dst + k * ld_dst * 2, vmask, vinputs[k]);
}
}
// value: from [K, 32] to [K/2, 32, 2]
template <typename scalar_t, typename index_t>
inline void pack_vnni_Kx32(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int K,
int ld_src,
int ld_dst) {
__m512i vinputs[2];
int k = 0;
for (; k < K; ++k) {
index_t index = get_index(ind, k);
vinputs[k] = _mm512_loadu_si512(src + index * ld_src);
}
// padding with zero to avoid uninitialized vectors
for (; k < 2; ++k) {
vinputs[k] = _mm512_set1_epi32(0);
}
// pack value
__m512i d0, d1;
std::tie(d0, d1) = transpose_2x32_16bit(vinputs[0], vinputs[1]);
_mm512_storeu_si512(dst + 0 * ld_dst * 2, d0);
_mm512_storeu_si512(dst + 0 * ld_dst * 2 + 32, d1);
}
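// Editor's note (illustrative, not part of this commit): transpose_2x32_16bit
// interleaves the two loaded rows element-wise, i.e. dst[2 * n + 0] = row0[n] and
// dst[2 * n + 1] = row1[n] for n in [0, 32); when K == 1 the second row is the zero
// vector from the padding loop above, so the odd lanes of the output are zero.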
#endif
// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
template <typename scalar_t, typename index_t>

@@ -26,6 +87,25 @@ void pack_vnni(
int K,
int ld_src,
int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
const int NB = div_up(N, 16);
const int KB = K / 32; // no remainder
const bool is_indexed = ind != nullptr;
for (int nb = 0; nb < NB; ++nb) {
for (int kb = 0; kb < KB; ++kb) {
// handle 16x512bits each block
int nb_size = std::min(N - nb * 16, 16);
pack_vnni_Nx32<scalar_t, index_t>(
/* dst */ dst + ((kb * 32) >> 1) * ld_dst * 2 + nb * 16 * 2,
/* src */ src + kb * 32 + (is_indexed ? 0 : nb * 16 * ld_src),
/* ind */ is_indexed ? ind + nb * 16 : nullptr,
/* N */ nb_size,
/* ld_src */ ld_src,
/* ld_dst */ ld_dst);
}
}
#else
for (int n = 0; n < N; ++n) {
index_t index = get_index(ind, n);
for (int k = 0; k < K / 2; ++k) {

@@ -34,6 +114,7 @@ void pack_vnni(
}
}
}
#endif
}
// convert to vnni format

@@ -47,6 +128,25 @@ void pack_vnni2(
int N,
int ld_src,
int ld_dst) {
#if defined(CPU_CAPABILITY_AVX512)
const int KB = div_up(K, 2);
const int NB = N / 32; // no remainder
const bool is_indexed = ind != nullptr;
for (int kb = 0; kb < KB; ++kb) {
for (int nb = 0; nb < NB; ++nb) {
// handle 2x512bits each block
int kb_size = std::min(K - kb * 2, 2);
pack_vnni_Kx32<scalar_t, index_t>(
/* dst */ dst + ((kb * 2) >> 1) * ld_dst * 2 + nb * 32 * 2,
/* src */ src + (is_indexed ? 0 : kb * 2 * ld_src) + nb * 32,
/* ind */ is_indexed ? ind + kb * 2 : nullptr,
/* K */ kb_size,
/* ld_src */ ld_src,
/* ld_dst */ ld_dst);
}
}
#else
int k = 0;
for (; k < (K >> 1) * 2; k += 2) {
index_t index0 = get_index(ind, k + 0);

@@ -64,21 +164,17 @@ void pack_vnni2(
}
k += 2;
}
#endif
}

template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
using Vec = at::vec::Vectorized<scalar_t>;
constexpr int kVecSize = Vec::size();
const Vec data_vec = Vec(static_cast<scalar_t>(val));
int d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
data_vec.store(out + d);
}
if (size - d > 0) {

@@ -110,9 +206,11 @@ template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_fvec = fVec(s);
int d = 0;
#pragma GCC unroll 4
for (; d <= size - kVecSize; d += kVecSize) {
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
......
@@ -93,6 +93,8 @@ void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,

@@ -135,6 +137,8 @@ void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,
......
@@ -2,6 +2,9 @@
#include "gemm.h"
#include "vec.h"
// we use 4x32 for BLOCK_M
#define BLOCK_SIZE_M_SCALE 4
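// Editor's note (illustrative, not part of this commit): per the "4x32" comment above,
// BLOCK_M becomes 4 * 32 = 128 rows. The per-thread scratch sized later as
// BLOCK_N * K + BLOCK_M * BLOCK_N * 2 elements reserves its trailing part for the
// float accumulator C_tmp; one float occupies two 16-bit elements of the buffer,
// hence the factor 2.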
namespace {

template <typename scalar_t>

@@ -61,33 +64,38 @@ inline void unpack_B(
constexpr int BLOCK_N = block_size_n();
static_assert(BLOCK_N == 32);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
#pragma GCC unroll 4
for (int k = 0; k < K2; ++k) {
__m512i b8 = _mm512_loadu_si512(b_ptr + k * ldb2);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2, _MM_HINT_T0);
}
__m256i b8_0 = _mm512_extracti32x8_epi32(b8, 0);
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);

// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));

f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);

bf16_0 = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
bf16_1 = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);

_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 0, (__m512i)bf16_0);
_mm512_storeu_si512(Btmp + k * ldb_tmp * 2 + 32, (__m512i)bf16_1);
}
#else
TORCH_CHECK(false, "unpack_B: scalar path not implemented!");
@@ -128,24 +136,30 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;

const int KB = div_up(K, BLOCK_K);

// prefetch distance
constexpr int PREFETCH_SIZE_K = 64;
constexpr int PREFETCH_SIZE_KB = 1;

__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vsum[ROWS * COLS];

// block quant scale
__m512 vscale;

auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_setzero_ps();
}
};
Unroll<ROWS * COLS>{}(loadc);

const int lda2 = lda >> 1;
const int ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);

@@ -155,11 +169,11 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(a_ptr + row * lda2 + k + PREFETCH_SIZE_K, _MM_HINT_T0);
}
}
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {

@@ -167,47 +181,40 @@ struct tinygemm_kernel_nn<at::BFloat16, at::Float8_e4m3fn, has_bias, BLOCK_M, BL
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
vb[col + 0] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 0));
vb[col + 1] = CVT_FP8_TO_BF16(_mm512_extracti32x8_epi32(b8, 1));
__m256i b8_1 = _mm512_extracti32x8_epi32(b8, 1);
__m512bh bf16_0 = CVT_FP8_TO_BF16(b8_0);
__m512bh bf16_1 = CVT_FP8_TO_BF16(b8_1);
// Apply scale
__m512 f0_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 0));
__m512 f0_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_0, 1));
__m512 f1_lo = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 0));
__m512 f1_hi = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32((__m512i)bf16_1, 1));
f0_lo = _mm512_mul_ps(f0_lo, vd);
f0_hi = _mm512_mul_ps(f0_hi, vd);
f1_lo = _mm512_mul_ps(f1_lo, vd);
f1_hi = _mm512_mul_ps(f1_hi, vd);
vb[col + 0] = _mm512_cvtne2ps_pbh(f0_hi, f0_lo);
vb[col + 1] = _mm512_cvtne2ps_pbh(f1_hi, f1_lo);
} }
} }
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]); vsum[i] = _mm512_dpbf16_ps(vsum[i], va, vb[col]);
}; };
constexpr int BLOCK_K2 = BLOCK_K >> 1;
for (int kb = 0; kb < KB; ++kb) {
for (int kb = 0; kb < KB; ++kb) {
int kb_start = kb * BLOCK_K2;
int kb_end = std::min(K >> 1, kb_start + BLOCK_K2);
// 1. load scale vector
vscale = _mm512_set1_ps(scale[kb]);
if constexpr (PREFETCH_SIZE_KB > 0) {
_mm_prefetch(scale + kb + PREFETCH_SIZE_KB, _MM_HINT_T0);
}
// 2. zero vsum for each block
Unroll<ROWS * COLS>{}([&](auto i) { vsum[i] = _mm512_setzero_ps(); });
// 3. accumulate across each block
for (int k = kb_start; k < kb_end; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
// 4. apply scale
Unroll<ROWS * COLS>{}([&](auto i) { vc[i] = _mm512_fmadd_ps(vsum[i], vscale, vc[i]); });
} }
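// Editor's note (illustrative, not part of this commit): the loop above computes, per
// BLOCK_K-wide (128-column) block,
//   vc += scale[kb] * sum_{k in block} A[:, 2k:2k+2] . B_fp8[2k:2k+2, :]
// i.e. the fp8 weights are dot-producted in bf16 without per-element dequantization,
// and the per-block quantization scale is applied once per block via a single FMA,
// instead of rescaling every converted B tile as the previous revision did.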
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
};
Unroll<ROWS * COLS>{}(storec);
@@ -266,22 +273,18 @@ struct brgemm<at::BFloat16, at::Float8_e4m3fn, has_bias> {
int ldc) {
constexpr int BLOCK_N = block_size_n();

// [K, BLOCK_N] -> [K / 2, BLOCK_N * 2]
const int ldb_tmp = BLOCK_N;
static_assert(BLOCK_K == 128);

for (int k = 0; k < K; k += BLOCK_K) {
int kb_size = std::min(BLOCK_K, K - k);

int idx = k >> 7; // k / BLOCK_K where BLOCK_K = 128
unpack_B(Btmp + k * ldb_tmp, B + k * ldb, N, kb_size, ldb, ldb_tmp, scale[idx]);
}

at::native::cpublas::brgemm(M, N, K, lda, ldb_tmp, BLOCK_N, /* add_C */ false, A, Btmp, Ctmp);
// copy from Ctmp to C
for (int m = 0; m < M; ++m) {
if constexpr (has_bias) {
@@ -328,34 +331,18 @@ void tinygemm_kernel(
int64_t nb_size = std::min(BLOCK_N, N - nb_start);

switch (mb_size << 4 | nb_size >> 4) {
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
@@ -370,14 +357,16 @@ void fp8_scaled_mm_kernel_impl(
const at::Float8_e4m3fn* __restrict__ mat2,
const float* __restrict__ scales2,
const float* __restrict__ bias,
scalar_t* __restrict__ buffer,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM,
int64_t block_size_N,
int64_t block_size_K,
int64_t buffer_size_per_thread) {
constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);

@@ -393,10 +382,9 @@ void fp8_scaled_mm_kernel_impl(
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);

int tid = at::get_thread_num();
scalar_t* __restrict__ Btmp = buffer + tid * buffer_size_per_thread;
float* __restrict__ Ctmp = (float*)((void*)(Btmp + BLOCK_N * K));

for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
@@ -507,6 +495,7 @@ at::Tensor fp8_scaled_mm_cpu(
int64_t block_size_N = block_size[0];
int64_t block_size_K = block_size[1];

constexpr int64_t BLOCK_M = block_size_m() * BLOCK_SIZE_M_SCALE;
constexpr int64_t BLOCK_N = block_size_n();
TORCH_CHECK(block_size_N % BLOCK_N == 0, "fp8_scaled_mm_cpu: expect block_size_N to be multiples of BLOCK_N");
TORCH_CHECK(block_size_K == BLOCK_K, "fp8_scaled_mm_cpu: expect block_size_K equals to BLOCK_K");

@@ -531,6 +520,12 @@ at::Tensor fp8_scaled_mm_cpu(
bias_data = bias.value().data_ptr<float>();
}
// Btmp : [T, BLOCK_N * K]
// Ctmp : [T, BLOCK_M * BLOCK_N]
int num_threads = at::get_num_threads();
int64_t size_per_thread = BLOCK_N * K + BLOCK_M * BLOCK_N * 2;
auto buffer = at::empty({num_threads, size_per_thread}, mat1.options());
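// Editor's note (illustrative, not part of this commit): size_per_thread counts
// scalar_t elements, BLOCK_N * K for the dequantized B tile (Btmp) plus
// BLOCK_M * BLOCK_N * 2 for the float accumulator (Ctmp), since one float spans two
// 16-bit elements. For example, assuming K = 4096 with BLOCK_N = 32 and BLOCK_M = 128,
// that is 32 * 4096 + 128 * 32 * 2 = 139264 elements, about 272 KiB of bf16 per thread.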
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "fp8_scaled_mm_kernel_impl", [&] {
fp8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),

@@ -538,13 +533,15 @@ at::Tensor fp8_scaled_mm_cpu(
packed_w.data_ptr<at::Float8_e4m3fn>(),
scales2.data_ptr<float>(),
bias_data,
buffer.data_ptr<scalar_t>(),
M,
N,
K,
mat1_strideM,
out_strideM,
block_size_N,
block_size_K,
size_per_thread);
});

return out;
......
@@ -33,11 +33,11 @@ void initialize(int64_t size, int64_t rank) {
world_rank = rank;
is_initialized = true;

const char* addr_string = std::getenv("MASTER_ADDR");
if (addr_string == NULL) {
addr_string = "";
}
const char* port_string = std::getenv("MASTER_PORT");
if (port_string == NULL) {
port_string = "";
}
......
@@ -1080,7 +1080,8 @@ at::Tensor fused_experts_cpu(
// 6. As_tmp : [M * topk]
//
// for fp8 w8a16:
// 7. intermediate_cache0 : [M * topk, 2N]
// 8. B_tmp : [T, BLOCK_N, std::max(K, N)]
//
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +

@@ -1090,7 +1091,7 @@ at::Tensor fused_experts_cpu(
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * topk * 2 * N * 2 + num_threads * BLOCK_N * std::max(K, N) * 2;
}
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));

@@ -1136,7 +1137,9 @@ at::Tensor fused_experts_cpu(
} else if (use_fp8_w8a16) {
scalar_t* __restrict__ A_tmp = (scalar_t*)((void*)(intermediate_cache2 + M * topk * K));
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * topk * 2 * N));
CHECK_MOE_SCALES_FP8(1, 2);
fused_experts_fp8_kernel_impl(

@@ -1145,6 +1148,8 @@ at::Tensor fused_experts_cpu(
intermediate_cache1,
intermediate_cache2,
A_tmp,
B_tmp,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),

@@ -1258,6 +1263,7 @@ at::Tensor shared_expert_cpu(
//
// for fp8 w8a16:
// 5. intermediate_cache0 : [M, 2N]
// 6. B_tmp : [T, BLOCK_M, max(K, N)]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);

@@ -1266,7 +1272,7 @@ at::Tensor shared_expert_cpu(
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
}
if (use_fp8_w8a16) {
buffer_size_nbytes += M * 2 * N * 2 + num_threads * BLOCK_M * std::max(K, N) * 2;
}
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));

@@ -1301,12 +1307,15 @@ at::Tensor shared_expert_cpu(
K);
} else if (use_fp8_w8a16) {
scalar_t* __restrict__ intermediate_cache0 = (scalar_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
scalar_t* __restrict__ B_tmp = (scalar_t*)((void*)(intermediate_cache0 + M * 2 * N));
CHECK_MOE_SCALES_FP8(0, 1);
shared_expert_fp8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache0,
intermediate_cache1,
B_tmp,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<at::Float8_e4m3fn>(),
packed_w2.data_ptr<at::Float8_e4m3fn>(),
......
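// Editor's note (illustrative, not part of this commit): the fp8 MoE path above carves
// its scratch out of one byte buffer in declaration order, roughly
//   A_tmp               : [T, BLOCK_M, K]            (scalar_t)
//   C_tmp               : [T, 2, BLOCK_M, BLOCK_N]   (float)
//   intermediate_cache0 : [M * topk, 2N]             (scalar_t)
//   B_tmp               : [T, BLOCK_N, max(K, N)]    (scalar_t)
// which is why buffer_size_nbytes gains the extra num_threads * BLOCK_N * max(K, N) * 2
// bytes when use_fp8_w8a16 is set.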
@@ -142,6 +142,8 @@ void fused_experts_fp8_kernel_impl(
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,

@@ -178,9 +180,6 @@ void fused_experts_fp8_kernel_impl(
int tid = at::get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;

bool is_brgemm_used = false;

for (int64_t i = begin; i < end; ++i) {

@@ -212,8 +211,8 @@ void fused_experts_fp8_kernel_impl(
/* A */ A,
/* B */ B,
/* C */ ic0 + offset * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,

@@ -250,9 +249,8 @@ void fused_experts_fp8_kernel_impl(
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K];

bool is_brgemm_used = false;

@@ -281,8 +279,8 @@ void fused_experts_fp8_kernel_impl(
/* A */ A,
/* B */ B,
/* C */ C,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ Bs,
/* M */ m_size,
/* N */ n_size,

@@ -323,6 +321,8 @@ void fused_experts_fp8_kernel_impl(
TYPE* __restrict__ ic1, \
TYPE* __restrict__ ic2, \
TYPE* __restrict__ A_tmp, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \

@@ -349,6 +349,8 @@ void shared_expert_fp8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic0,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ B_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const at::Float8_e4m3fn* __restrict__ packed_w1,
const at::Float8_e4m3fn* __restrict__ packed_w2,

@@ -373,8 +375,7 @@ void shared_expert_fp8_kernel_impl(
const bool use_brgemm = can_use_brgemm<at::Float8_e4m3fn>(M);

at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();

for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;

@@ -386,8 +387,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ input + mb * BLOCK_M * K,
/* B */ packed_w1 + nb * BLOCK_N * K,
/* C */ ic0 + mb * BLOCK_M * 2 * N + nb * BLOCK_N,
/* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w1s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size, /* M */ m_size,
/* N */ n_size, /* N */ n_size,
...@@ -421,9 +422,8 @@ void shared_expert_fp8_kernel_impl( ...@@ -421,9 +422,8 @@ void shared_expert_fp8_kernel_impl(
// parallel on [MB2, NB2] // parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) { at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
alignas(64) scalar_t Btmp[BLOCK_K * BLOCK_N]; int tid = at::get_thread_num();
alignas(64) scalar_t C[BLOCK_M * BLOCK_K]; alignas(64) scalar_t C[BLOCK_M * BLOCK_K];
alignas(64) float Ctmp[BLOCK_M * BLOCK_K];
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2; int64_t mb = i / NB2;
...@@ -436,8 +436,8 @@ void shared_expert_fp8_kernel_impl( ...@@ -436,8 +436,8 @@ void shared_expert_fp8_kernel_impl(
/* A */ ic1 + mb * BLOCK_M * N, /* A */ ic1 + mb * BLOCK_M * N,
/* B */ packed_w2 + nb * BLOCK_N * N, /* B */ packed_w2 + nb * BLOCK_N * N,
/* C */ C, /* C */ C,
/* Btmp */ Btmp, /* Btmp */ B_tmp + tid * BLOCK_N * std::max(K, N),
/* Ctmp */ Ctmp, /* Ctmp */ C_tmp + tid * 2 * BLOCK_M * BLOCK_N,
/* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K, /* scale */ w2s + (nb / blocks_n_per_group) * scale_size_K,
/* M */ m_size, /* M */ m_size,
/* N */ n_size, /* N */ n_size,
...@@ -467,6 +467,8 @@ void shared_expert_fp8_kernel_impl( ...@@ -467,6 +467,8 @@ void shared_expert_fp8_kernel_impl(
TYPE* __restrict__ output, \ TYPE* __restrict__ output, \
TYPE* __restrict__ ic0, \ TYPE* __restrict__ ic0, \
TYPE* __restrict__ ic1, \ TYPE* __restrict__ ic1, \
TYPE* __restrict__ B_tmp, \
float* __restrict__ C_tmp, \
const TYPE* __restrict__ input, \ const TYPE* __restrict__ input, \
const at::Float8_e4m3fn* __restrict__ packed_w1, \ const at::Float8_e4m3fn* __restrict__ packed_w1, \
const at::Float8_e4m3fn* __restrict__ packed_w2, \ const at::Float8_e4m3fn* __restrict__ packed_w2, \
......
...@@ -72,6 +72,7 @@ void rmsnorm_kernel_impl( ...@@ -72,6 +72,7 @@ void rmsnorm_kernel_impl(
const scalar_t* __restrict__ weight, const scalar_t* __restrict__ weight,
int64_t batch_size, int64_t batch_size,
int64_t hidden_size, int64_t hidden_size,
int64_t input_strideN,
float eps = 1e-5) { float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>; using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>; using fVec = at::vec::Vectorized<float>;
...@@ -81,7 +82,7 @@ void rmsnorm_kernel_impl( ...@@ -81,7 +82,7 @@ void rmsnorm_kernel_impl(
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
// local ptrs // local ptrs
scalar_t* __restrict__ out_ptr = output + i * hidden_size; scalar_t* __restrict__ out_ptr = output + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * hidden_size; const scalar_t* __restrict__ input_ptr = input + i * input_strideN;
fVec sum_fvec = fVec(float(0)); fVec sum_fvec = fVec(float(0));
float sum_val = float(0); float sum_val = float(0);
...@@ -140,6 +141,7 @@ void fused_add_rmsnorm_kernel_impl( ...@@ -140,6 +141,7 @@ void fused_add_rmsnorm_kernel_impl(
float* __restrict__ buffer, float* __restrict__ buffer,
int64_t batch_size, int64_t batch_size,
int64_t hidden_size, int64_t hidden_size,
int64_t input_strideN,
float eps = 1e-5) { float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>; using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>; using fVec = at::vec::Vectorized<float>;
...@@ -151,7 +153,7 @@ void fused_add_rmsnorm_kernel_impl( ...@@ -151,7 +153,7 @@ void fused_add_rmsnorm_kernel_impl(
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
// local ptrs // local ptrs
scalar_t* __restrict__ input_ptr = input + i * hidden_size; scalar_t* __restrict__ input_ptr = input + i * input_strideN;
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size; scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
fVec sum_fvec = fVec(float(0)); fVec sum_fvec = fVec(float(0));
...@@ -242,7 +244,7 @@ at::Tensor l2norm_cpu(at::Tensor& input, double eps) { ...@@ -242,7 +244,7 @@ at::Tensor l2norm_cpu(at::Tensor& input, double eps) {
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) { at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight})); RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
CHECK_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_INPUT(weight); CHECK_INPUT(weight);
CHECK_DIM(2, input); CHECK_DIM(2, input);
CHECK_DIM(1, weight); CHECK_DIM(1, weight);
...@@ -250,6 +252,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) { ...@@ -250,6 +252,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
int64_t batch_size = input.size(0); int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1); int64_t hidden_size = input.size(1);
at::Tensor output = at::empty_like(input); at::Tensor output = at::empty_like(input);
int64_t input_strideN = input.stride(0);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] { AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
rmsnorm_kernel_impl<scalar_t>( rmsnorm_kernel_impl<scalar_t>(
...@@ -258,6 +261,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) { ...@@ -258,6 +261,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
weight.data_ptr<scalar_t>(), weight.data_ptr<scalar_t>(),
batch_size, batch_size,
hidden_size, hidden_size,
input_strideN,
eps); eps);
}); });
return output; return output;
...@@ -268,7 +272,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) { ...@@ -268,7 +272,7 @@ at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
// weight : {hidden_size} // weight : {hidden_size}
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) { void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight})); RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
CHECK_INPUT(input); CHECK_LAST_DIM_CONTIGUOUS_INPUT(input);
CHECK_INPUT(residual); CHECK_INPUT(residual);
CHECK_INPUT(weight); CHECK_INPUT(weight);
CHECK_DIM(2, input); CHECK_DIM(2, input);
...@@ -279,6 +283,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& ...@@ -279,6 +283,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
CHECK_EQ(input.size(1), weight.size(0)); CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0); int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1); int64_t hidden_size = input.size(1);
int64_t input_strideN = input.stride(0);
// allocate temp buffer to store x in float32 per thread // allocate temp buffer to store x in float32 per thread
// TODO: implement a singleton for context // TODO: implement a singleton for context
...@@ -293,6 +298,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& ...@@ -293,6 +298,7 @@ void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor&
buffer.data_ptr<float>(), buffer.data_ptr<float>(),
batch_size, batch_size,
hidden_size, hidden_size,
input_strideN,
eps); eps);
}); });
} }
...@@ -162,6 +162,7 @@ void segment_gemm_kernel_impl( ...@@ -162,6 +162,7 @@ void segment_gemm_kernel_impl(
const at::Float8_e4m3fn* __restrict__ B1, const at::Float8_e4m3fn* __restrict__ B1,
const float* __restrict__ Bs0, const float* __restrict__ Bs0,
const float* __restrict__ Bs1, const float* __restrict__ Bs1,
scalar_t* __restrict__ Btmp,
int64_t M, int64_t M,
int64_t N0, int64_t N0,
int64_t N1, int64_t N1,
...@@ -185,10 +186,9 @@ void segment_gemm_kernel_impl( ...@@ -185,10 +186,9 @@ void segment_gemm_kernel_impl(
int64_t mb{0}, nb{0}; int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB); data_index_init(begin, mb, MB, nb, NB);
int tid = at::get_thread_num();
// for brgemm, use float32 for accumulate // for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N]; alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
// for brgemm when mat2 is float8_e4m3
alignas(64) scalar_t Btmp[BLOCK_N * BLOCK_K];
for (int64_t i = begin; i < end; ++i) { for (int64_t i = begin; i < end; ++i) {
UNUSED(i); UNUSED(i);
...@@ -209,7 +209,7 @@ void segment_gemm_kernel_impl( ...@@ -209,7 +209,7 @@ void segment_gemm_kernel_impl(
/* A */ A + mb_start * K, /* A */ A + mb_start * K,
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */, /* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
/* C */ C + mb_start * ldc + local_nb_start, /* C */ C + mb_start * ldc + local_nb_start,
/* Btmp*/ Btmp, /* Btmp*/ Btmp + tid * BLOCK_N * K,
/* Ctmp*/ Ctmp, /* Ctmp*/ Ctmp,
/* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K, /* Bs */ Bs + (new_nb / blocks_n_per_group) * scale_size_K,
/* M */ mb_size, /* M */ mb_size,
...@@ -541,6 +541,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -541,6 +541,10 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K)); CHECK_EQ(q_a_proj_s.size(1), div_up(hidden_size, block_size_K));
CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N)); CHECK_EQ(kv_a_proj_s.size(0), div_up(kv_lora_rank + qk_rope_head_dim, block_size_N));
CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K)); CHECK_EQ(kv_a_proj_s.size(1), div_up(hidden_size, block_size_K));
const int BLOCK_N = block_size_n();
const int num_threads = at::get_num_threads();
auto buffer = at::empty({num_threads, BLOCK_N * hidden_size}, options);
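// per-thread scratch handed to segment_gemm_kernel_impl as Btmp: it replaces the
// former per-call stack array and holds dequantized fp8 weight tiles of up to [BLOCK_N, K]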
segment_gemm_kernel_impl<scalar_t>( segment_gemm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(), qa.data_ptr<scalar_t>(),
k_input.data_ptr<scalar_t>(), k_input.data_ptr<scalar_t>(),
...@@ -549,6 +553,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -549,6 +553,7 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(), kv_a_proj_weight.data_ptr<at::Float8_e4m3fn>(),
q_a_proj_s.data_ptr<float>(), q_a_proj_s.data_ptr<float>(),
kv_a_proj_s.data_ptr<float>(), kv_a_proj_s.data_ptr<float>(),
buffer.data_ptr<scalar_t>(),
num_seqs, num_seqs,
q_lora_rank, q_lora_rank,
kv_lora_rank + qk_rope_head_dim, kv_lora_rank + qk_rope_head_dim,
...@@ -624,3 +629,74 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -624,3 +629,74 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
return std::make_tuple(q_input, k_input, v_input); return std::make_tuple(q_input, k_input, v_input);
} }
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
at::Tensor& hidden_states,
at::Tensor& qkv_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> qkv_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
bool is_vnni,
std::optional<std::vector<int64_t>> block_size,
int64_t q_lora_rank,
int64_t kv_lora_rank,
int64_t qk_rope_head_dim) {
RECORD_FUNCTION(
"sgl-kernel::qkv_proj_with_rope_fused_weight",
std::vector<c10::IValue>({hidden_states, qkv_a_proj_weight, q_b_proj_weight, w_kc}));
int64_t hidden_size = hidden_states.size(1);
CHECK_EQ(qkv_a_proj_weight.size(0), q_lora_rank + kv_lora_rank + qk_rope_head_dim);
CHECK_EQ(qkv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
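// split the fused qkv_a projection weight (and its scales below) back into the
// q_a and kv_a parts expected by qkv_proj_with_rope, which does the actual work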
std::vector<at::Tensor> weight_chunks =
at::split(qkv_a_proj_weight, {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
at::Tensor q_a_proj_weight = weight_chunks[0];
at::Tensor kv_a_proj_weight = weight_chunks[1];
at::Tensor q_a_proj_s;
at::Tensor kv_a_proj_s;
if (use_int8_w8a8) {
TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for int8 w8a8.");
std::vector<at::Tensor> scale_chunks =
at::split(qkv_a_proj_scale.value(), {q_lora_rank, kv_lora_rank + qk_rope_head_dim}, 0);
q_a_proj_s = scale_chunks[0];
kv_a_proj_s = scale_chunks[1];
}
if (use_fp8_w8a16) {
TORCH_CHECK(qkv_a_proj_scale.has_value(), "missing qkv_a_proj_scale for fp8 w8a16.");
int64_t block_size_N = block_size.value()[0];
int64_t q_a_proj_s_dim0 = div_up(q_lora_rank, block_size_N);
int64_t kv_a_proj_s_dim0 = div_up(kv_lora_rank + qk_rope_head_dim, block_size_N);
std::vector<at::Tensor> scale_chunks = at::split(qkv_a_proj_scale.value(), {q_a_proj_s_dim0, kv_a_proj_s_dim0}, 0);
q_a_proj_s = scale_chunks[0];
kv_a_proj_s = scale_chunks[1];
}
return qkv_proj_with_rope(
hidden_states,
q_a_proj_weight,
q_b_proj_weight,
kv_a_proj_weight,
w_kc,
q_a_layernorm_weight,
kv_a_layernorm_weight,
positions,
cos_sin_cache,
eps,
use_int8_w8a8,
use_fp8_w8a16,
q_a_proj_s,
q_b_proj_scale,
kv_a_proj_s,
is_vnni,
block_size);
}
...@@ -54,7 +54,8 @@ void shared_open(SharedData* data, const char* name, size_t nbytes) { ...@@ -54,7 +54,8 @@ void shared_open(SharedData* data, const char* name, size_t nbytes) {
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) { void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR); int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
if (d != -1) { if (d != -1) {
if (nbytes = write(d, bytes, nbytes)) { nbytes = write(d, bytes, nbytes);
if (nbytes > 0) {
shared_open(data, name, nbytes); shared_open(data, name, nbytes);
} }
} else { } else {
...@@ -391,7 +392,7 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, ...@@ -391,7 +392,7 @@ void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer,
static bool is_initialized = false; static bool is_initialized = false;
static int world_rank; static int world_rank;
void shm_initialize(int size, int rank, char* addr_string, char* port_string) { void shm_initialize(int size, int rank, const char* addr_string, const char* port_string) {
if (is_initialized) { if (is_initialized) {
return; return;
} }
...@@ -409,7 +410,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) { ...@@ -409,7 +410,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) {
struct allreduce_workspace* workspace_buf; struct allreduce_workspace* workspace_buf;
struct allreduce_workspace* workspace_buf_other; struct allreduce_workspace* workspace_buf_other;
workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace)); workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank); snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, rank);
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace)); shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes; workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done; workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
...@@ -425,7 +426,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) { ...@@ -425,7 +426,7 @@ void shm_initialize(int size, int rank, char* addr_string, char* port_string) {
// map shm of all ranks // map shm of all ranks
for (int i = 0; i < size; i++) { for (int i = 0; i < size; i++) {
if (i != rank) { if (i != rank) {
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i); snprintf(shm_name, NAME_BUF_SIZE, "%.900s_%d", shm_name_prefix, i);
// printf("open %s, %d\n", shm_name, rank); // printf("open %s, %d\n", shm_name, rank);
do { do {
shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace)); shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
...@@ -447,13 +448,13 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) { ...@@ -447,13 +448,13 @@ static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES); auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
// process aligned part // process aligned part
#pragma omp parallel for #pragma omp parallel for
for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) { for (size_t i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
auto val = _mm256_loadu_si256((__m256i*)((char*)from + i)); auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
_mm256_storeu_si256((__m256i*)((char*)to + i), val); _mm256_storeu_si256((__m256i*)((char*)to + i), val);
} }
// process remaining part // process remaining part
for (int i = aligned_bytes; i < n_bytes; i++) { for (size_t i = aligned_bytes; i < n_bytes; i++) {
*((char*)to + i) = *((char*)from + i); *((char*)to + i) = *((char*)from + i);
} }
} }
...@@ -481,7 +482,9 @@ void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, siz ...@@ -481,7 +482,9 @@ void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, siz
static int current_buffer = 0; static int current_buffer = 0;
static int state_idx = 0; static int state_idx = 0;
enum coll_state copy_current, copy_next; // init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
switch (state_idx) { switch (state_idx) {
case 0: case 0:
...@@ -526,7 +529,10 @@ void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_ ...@@ -526,7 +529,10 @@ void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_
static int current_buffer = 0; static int current_buffer = 0;
static int state_idx = 0; static int state_idx = 0;
enum coll_state copy_current, copy_next, reduce_current; // init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allreduce_naive__copy_in_done;
enum coll_state reduce_current = coll_allreduce_naive__reduce_done;
enum coll_state copy_next = coll_alt1_allreduce_naive__copy_in_done;
// similar to symmetric_naive_allreduce, but here we only need two sets of // similar to symmetric_naive_allreduce, but here we only need two sets of
// states, because distributed naive reduce has two barriers in the algorithm // states, because distributed naive reduce has two barriers in the algorithm
...@@ -601,7 +607,9 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_ ...@@ -601,7 +607,9 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_
static int current_buffer = 0; static int current_buffer = 0;
static int state_idx = 0; static int state_idx = 0;
enum coll_state copy_current, copy_next; // init states to case 0 to get rid of "maybe-uninitialized" warning.
enum coll_state copy_current = coll_allgather_naive__copy_in_done;
enum coll_state copy_next = coll_alt1_allgather_naive__copy_in_done;
switch (state_idx) { switch (state_idx) {
case 0: case 0:
...@@ -621,7 +629,6 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_ ...@@ -621,7 +629,6 @@ void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_
} }
state_idx = (state_idx + 1) % 3; state_idx = (state_idx + 1) % 3;
int data_size = chunk_size / chunk_el;
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size); parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release); std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current; workspace[world_rank]->states[state_group] = copy_current;
...@@ -644,7 +651,7 @@ torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, s ...@@ -644,7 +651,7 @@ torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, s
auto data_ptr = (char*)(data.data_ptr()); auto data_ptr = (char*)(data.data_ptr());
auto result_ptr = (char*)(result.data_ptr()); auto result_ptr = (char*)(result.data_ptr());
for (int i = 0; i < dim_count; i++) { for (int i = 0; i < dim_count; i++) {
for (int offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) { for (size_t offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset; size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
size_t chunk_el = chunk_size / dtype_size; size_t chunk_el = chunk_size / dtype_size;
naive_all_gather( naive_all_gather(
......
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#ifndef __SHM_COLLECTIVES__ #ifndef __SHM_COLLECTIVES__
#define __SHM_COLLECTIVES__ #define __SHM_COLLECTIVES__
#define VECTOR_LENGTH_IN_BYTES 32 #define VECTOR_LENGTH_IN_BYTES 32
void shm_initialize(int size, int rank, char* addr_string, char* port_string); void shm_initialize(int size, int rank, const char* addr_string, const char* port_string);
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size); void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size); torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
#endif #endif
...@@ -534,7 +534,25 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu( ...@@ -534,7 +534,25 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
int64_t topk, int64_t topk,
bool renormalize, bool renormalize,
int64_t num_expert_group, int64_t num_expert_group,
int64_t topk_group) { int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, only their default values are accepted.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be the default value 0, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!routed_scaling_factor.has_value() || routed_scaling_factor.value() == 1.0f,
"routed_scaling_factor must be None or the default value 1.0, got: ",
routed_scaling_factor.value());
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None (the default), got: ",
num_token_non_padded.value());
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output})); RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output); CHECK_INPUT(gating_output);
...@@ -594,7 +612,21 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu( ...@@ -594,7 +612,21 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
int64_t topk, int64_t topk,
bool renormalize, bool renormalize,
int64_t num_expert_group, int64_t num_expert_group,
int64_t topk_group) { int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded) {
// TODO: support num_fused_shared_experts, routed_scaling_factor and num_token_non_padded.
// For now, only their default values are accepted.
TORCH_CHECK(
num_fused_shared_experts == 0,
"num_fused_shared_experts must be the default value 0, got: ",
num_fused_shared_experts);
TORCH_CHECK(
!num_token_non_padded.has_value(),
"num_token_non_padded must be None (the default), got: ",
num_token_non_padded.value());
RECORD_FUNCTION( RECORD_FUNCTION(
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias})); "sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
......
...@@ -44,7 +44,10 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu( ...@@ -44,7 +44,10 @@ std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
int64_t topk, int64_t topk,
bool renormalize, bool renormalize,
int64_t num_expert_group, int64_t num_expert_group,
int64_t topk_group); int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded);
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu( std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor& hidden_states, at::Tensor& hidden_states,
...@@ -53,7 +56,10 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu( ...@@ -53,7 +56,10 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
int64_t topk, int64_t topk,
bool renormalize, bool renormalize,
int64_t num_expert_group, int64_t num_expert_group,
int64_t topk_group); int64_t topk_group,
int64_t num_fused_shared_experts,
std::optional<double> routed_scaling_factor,
std::optional<at::Tensor> num_token_non_padded);
// attention // attention
void decode_attention_cpu( void decode_attention_cpu(
...@@ -182,6 +188,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope( ...@@ -182,6 +188,26 @@ std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
bool is_vnni, bool is_vnni,
std::optional<std::vector<int64_t>> block_size); std::optional<std::vector<int64_t>> block_size);
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope_fused_weight(
at::Tensor& hidden_states,
at::Tensor& qkv_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
bool use_fp8_w8a16,
std::optional<at::Tensor> qkv_a_proj_scale,
std::optional<at::Tensor> q_b_proj_scale,
bool is_vnni,
std::optional<std::vector<int64_t>> block_size,
int64_t q_lora_rank,
int64_t kv_lora_rank,
int64_t qk_rope_head_dim);
// shared memory init // shared memory init
void initialize(int64_t size, int64_t rank); void initialize(int64_t size, int64_t rank);
...@@ -221,13 +247,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -221,13 +247,15 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu); m.impl("topk_softmax_cpu", torch::kCPU, &topk_softmax_cpu);
m.def( m.def(
"grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, " "grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, int topk, bool renormalize, int num_expert_group, "
"int topk_group) -> (Tensor, Tensor)"); "int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, Tensor? num_token_non_padded) -> "
"(Tensor, Tensor)");
m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu); m.impl("grouped_topk_cpu", torch::kCPU, &grouped_topk_cpu);
// biased group topk // biased group topk
m.def( m.def(
"biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool " "biased_grouped_topk_cpu(Tensor hidden_states, Tensor gating_output, Tensor correction_bias, int topk, bool "
"renormalize, int num_expert_group, int topk_group) -> (Tensor, Tensor)"); "renormalize, int num_expert_group, int topk_group, int num_fused_shared_experts, float? routed_scaling_factor, "
"Tensor? num_token_non_padded) -> (Tensor, Tensor)");
m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu); m.impl("biased_grouped_topk_cpu", torch::kCPU, &biased_grouped_topk_cpu);
// decode // decode
...@@ -294,6 +322,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { ...@@ -294,6 +322,14 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
"q_b_proj_scale, Tensor? " "q_b_proj_scale, Tensor? "
"kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)"); "kv_a_proj_scale, bool is_vnni, int[]? block_size) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope); m.impl("qkv_proj_with_rope", torch::kCPU, &qkv_proj_with_rope);
m.def(
"qkv_proj_with_rope_fused_weight(Tensor hidden_states, Tensor qkv_a_proj_weight, Tensor q_b_proj_weight, "
"Tensor w_kc, Tensor q_a_layernorm_weight, Tensor kv_a_layernorm_weight, Tensor positions, "
"Tensor cos_sin_cache, float eps, bool use_int8_w8a8, bool use_fp8_w8a16, Tensor? qkv_a_proj_scale, Tensor? "
"q_b_proj_scale,"
"bool is_vnni, int[]? block_size, int q_lora_rank, int kv_lora_rank,"
"int qk_rope_head_dim) -> (Tensor, Tensor, Tensor)");
m.impl("qkv_proj_with_rope_fused_weight", torch::kCPU, &qkv_proj_with_rope_fused_weight);
// shared expert // shared expert
m.def( m.def(
......
...@@ -30,6 +30,22 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize ...@@ -30,6 +30,22 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC)) #define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
// this doesn't handle NaN.
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
const __m512i x = _mm512_cvtepu8_epi16(fp8_vec);
const __m512i mant = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x07)), 4);
const __m512i raw_exp = _mm512_srli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x78)), 3);
const __m512i exp = _mm512_slli_epi16(_mm512_add_epi16(raw_exp, _mm512_set1_epi16(120)), 7);
const __m512i nonsign = _mm512_or_si512(exp, mant);
const __m512i sign = _mm512_slli_epi16(_mm512_and_si512(x, _mm512_set1_epi16(0x80)), 8);
const __m512i combined = _mm512_or_si512(nonsign, sign);
const __mmask32 is_nonzero = _mm512_cmpneq_epi16_mask(x, _mm512_setzero_si512());
return (__m512bh)_mm512_maskz_mov_epi16(is_nonzero, combined);
}
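// [illustrative sketch, not used by the kernel] scalar equivalent of the bit trick
// above, assuming <cstdint> is available: widen the e4m3 byte to 16 bits, re-bias
// the exponent by 127 - 7 = 120, and move sign/exponent/mantissa into bf16 positions.
// As above, NaN is not handled and only +0 is forced to an exact zero.
inline uint16_t cvt_e4m3_bf16_scalar_sketch(uint8_t fp8) {
  const uint16_t x = fp8;                     // zero-extend, like _mm512_cvtepu8_epi16
  const uint16_t mant = (x & 0x07) << 4;      // 3-bit mantissa -> top of bf16 mantissa
  const uint16_t raw_exp = (x & 0x78) >> 3;   // 4-bit e4m3 exponent
  const uint16_t exp = (raw_exp + 120) << 7;  // re-bias and place in bf16 exponent field
  const uint16_t sign = (x & 0x80) << 8;      // sign bit to bit 15
  return x == 0 ? uint16_t(0) : uint16_t(sign | exp | mant);
}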
inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) { inline __m512bh cvt_e4m3_bf16_intrinsic_without_denorm(__m256i fp8_vec) {
// The following conversion is without denorm behavior, that is to say, // The following conversion is without denorm behavior, that is to say,
// Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6) // Max subnorm : S.0000.111 = 0.875 ∗ 2**(−6)
...@@ -84,7 +100,7 @@ inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) { ...@@ -84,7 +100,7 @@ inline __m512bh cvt_e4m3_bf16_intrinsic_with_denorm(__m256i fp8_vec) {
inline __m512bh CVT_FP8_TO_BF16(__m256i a) { inline __m512bh CVT_FP8_TO_BF16(__m256i a) {
#ifdef SGLANG_CPU_FP8_CVT_FTZ #ifdef SGLANG_CPU_FP8_CVT_FTZ
return cvt_e4m3_bf16_intrinsic_without_denorm(a); return cvt_e4m3_bf16_intrinsic_no_nan(a);
#else #else
return cvt_e4m3_bf16_intrinsic_with_denorm(a); return cvt_e4m3_bf16_intrinsic_with_denorm(a);
#endif #endif
...@@ -172,4 +188,102 @@ inline void quantize_row_int8<at::BFloat16>( ...@@ -172,4 +188,102 @@ inline void quantize_row_int8<at::BFloat16>(
} }
#endif #endif
// transpose utils
// taken from my PR in ggml: https://github.com/ggml-org/llama.cpp/pull/8998
#if defined(CPU_CAPABILITY_AVX512)
inline void transpose_16x16_32bit(__m512i* v) {
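// In-place transpose of a 16x16 tile of 32-bit lanes held in 16 zmm registers:
// interleave 32-bit pairs, then 64-bit pairs, then two rounds of 128-bit lane
// shuffles across register pairs.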
__m512i v1[16];
v1[0] = _mm512_unpacklo_epi32(v[0], v[1]);
v1[1] = _mm512_unpackhi_epi32(v[0], v[1]);
v1[2] = _mm512_unpacklo_epi32(v[2], v[3]);
v1[3] = _mm512_unpackhi_epi32(v[2], v[3]);
v1[4] = _mm512_unpacklo_epi32(v[4], v[5]);
v1[5] = _mm512_unpackhi_epi32(v[4], v[5]);
v1[6] = _mm512_unpacklo_epi32(v[6], v[7]);
v1[7] = _mm512_unpackhi_epi32(v[6], v[7]);
v1[8] = _mm512_unpacklo_epi32(v[8], v[9]);
v1[9] = _mm512_unpackhi_epi32(v[8], v[9]);
v1[10] = _mm512_unpacklo_epi32(v[10], v[11]);
v1[11] = _mm512_unpackhi_epi32(v[10], v[11]);
v1[12] = _mm512_unpacklo_epi32(v[12], v[13]);
v1[13] = _mm512_unpackhi_epi32(v[12], v[13]);
v1[14] = _mm512_unpacklo_epi32(v[14], v[15]);
v1[15] = _mm512_unpackhi_epi32(v[14], v[15]);
v[0] = _mm512_unpacklo_epi64(v1[0], v1[2]);
v[1] = _mm512_unpackhi_epi64(v1[0], v1[2]);
v[2] = _mm512_unpacklo_epi64(v1[1], v1[3]);
v[3] = _mm512_unpackhi_epi64(v1[1], v1[3]);
v[4] = _mm512_unpacklo_epi64(v1[4], v1[6]);
v[5] = _mm512_unpackhi_epi64(v1[4], v1[6]);
v[6] = _mm512_unpacklo_epi64(v1[5], v1[7]);
v[7] = _mm512_unpackhi_epi64(v1[5], v1[7]);
v[8] = _mm512_unpacklo_epi64(v1[8], v1[10]);
v[9] = _mm512_unpackhi_epi64(v1[8], v1[10]);
v[10] = _mm512_unpacklo_epi64(v1[9], v1[11]);
v[11] = _mm512_unpackhi_epi64(v1[9], v1[11]);
v[12] = _mm512_unpacklo_epi64(v1[12], v1[14]);
v[13] = _mm512_unpackhi_epi64(v1[12], v1[14]);
v[14] = _mm512_unpacklo_epi64(v1[13], v1[15]);
v[15] = _mm512_unpackhi_epi64(v1[13], v1[15]);
v1[0] = _mm512_shuffle_i32x4(v[0], v[4], 0x88);
v1[1] = _mm512_shuffle_i32x4(v[1], v[5], 0x88);
v1[2] = _mm512_shuffle_i32x4(v[2], v[6], 0x88);
v1[3] = _mm512_shuffle_i32x4(v[3], v[7], 0x88);
v1[4] = _mm512_shuffle_i32x4(v[0], v[4], 0xdd);
v1[5] = _mm512_shuffle_i32x4(v[1], v[5], 0xdd);
v1[6] = _mm512_shuffle_i32x4(v[2], v[6], 0xdd);
v1[7] = _mm512_shuffle_i32x4(v[3], v[7], 0xdd);
v1[8] = _mm512_shuffle_i32x4(v[8], v[12], 0x88);
v1[9] = _mm512_shuffle_i32x4(v[9], v[13], 0x88);
v1[10] = _mm512_shuffle_i32x4(v[10], v[14], 0x88);
v1[11] = _mm512_shuffle_i32x4(v[11], v[15], 0x88);
v1[12] = _mm512_shuffle_i32x4(v[8], v[12], 0xdd);
v1[13] = _mm512_shuffle_i32x4(v[9], v[13], 0xdd);
v1[14] = _mm512_shuffle_i32x4(v[10], v[14], 0xdd);
v1[15] = _mm512_shuffle_i32x4(v[11], v[15], 0xdd);
v[0] = _mm512_shuffle_i32x4(v1[0], v1[8], 0x88);
v[1] = _mm512_shuffle_i32x4(v1[1], v1[9], 0x88);
v[2] = _mm512_shuffle_i32x4(v1[2], v1[10], 0x88);
v[3] = _mm512_shuffle_i32x4(v1[3], v1[11], 0x88);
v[4] = _mm512_shuffle_i32x4(v1[4], v1[12], 0x88);
v[5] = _mm512_shuffle_i32x4(v1[5], v1[13], 0x88);
v[6] = _mm512_shuffle_i32x4(v1[6], v1[14], 0x88);
v[7] = _mm512_shuffle_i32x4(v1[7], v1[15], 0x88);
v[8] = _mm512_shuffle_i32x4(v1[0], v1[8], 0xdd);
v[9] = _mm512_shuffle_i32x4(v1[1], v1[9], 0xdd);
v[10] = _mm512_shuffle_i32x4(v1[2], v1[10], 0xdd);
v[11] = _mm512_shuffle_i32x4(v1[3], v1[11], 0xdd);
v[12] = _mm512_shuffle_i32x4(v1[4], v1[12], 0xdd);
v[13] = _mm512_shuffle_i32x4(v1[5], v1[13], 0xdd);
v[14] = _mm512_shuffle_i32x4(v1[6], v1[14], 0xdd);
v[15] = _mm512_shuffle_i32x4(v1[7], v1[15], 0xdd);
}
// remove warning: ignoring attributes on template argument ‘__m512i’ [-Wignored-attributes]
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wignored-attributes"
// transpose from [2, 32] to [32, 2]
inline std::tuple<__m512i, __m512i> transpose_2x32_16bit(__m512i r0, __m512i r1) {
// r0: {a0, a1, ..., a31}
// r1: {b0, b1, ..., b31}
//
// d0: {a0, b0, ..., a15, b15}
// d1: {a16, b16, ..., a31, b31}
//
__m512i d0 = _mm512_unpacklo_epi16(r0, r1);
__m512i d1 = _mm512_unpackhi_epi16(r0, r1);
r0 = _mm512_shuffle_i32x4(d0, d1, 0x88);
r1 = _mm512_shuffle_i32x4(d0, d1, 0xdd);
d0 = _mm512_shuffle_i32x4(r0, r1, 0x88);
d1 = _mm512_shuffle_i32x4(r0, r1, 0xdd);
return std::make_tuple(d0, d1);
}
#pragma GCC diagnostic pop
#endif
} // anonymous namespace } // anonymous namespace
import itertools
import unittest
import sgl_kernel
import torch
from torch.nn.functional import scaled_dot_product_attention
from utils import precision
from sglang.test.test_utils import CustomTestCase
class TestMLA(CustomTestCase):
def _run_sdpa_forward_decode(
self,
query: torch.Tensor,
output: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
key: torch.Tensor,
loc: torch.Tensor,
req_to_token: torch.Tensor,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
scaling=None,
enable_gqa=False,
causal=False,
):
# set kv cache
k_cache[loc] = key
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
query = query.movedim(0, query.dim() - 2)
start_q, start_kv = 0, 0
for seq_idx in range(seq_lens.shape[0]):
seq_len_q = 1
seq_len_kv = seq_lens[seq_idx]
end_q = start_q + seq_len_q
end_kv = start_kv + seq_len_kv
per_req_query = query[:, start_q:end_q, :]
# get key and value from cache. per_req_tokens contains the kv cache
# index for each token in the sequence.
req_pool_idx = req_pool_indices[seq_idx]
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
per_req_out = (
scaled_dot_product_attention(
per_req_query.unsqueeze(0),
per_req_key.unsqueeze(0),
per_req_value.unsqueeze(0),
enable_gqa=enable_gqa,
scale=scaling,
is_causal=causal,
)
.squeeze(0)
.movedim(query.dim() - 2, 0)
)
output[start_q:end_q, :, :] = per_req_out
start_q, start_kv = end_q, end_kv
return output
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, seq_len):
dtype = torch.bfloat16
total_tokens = B * seq_len
sm_scale = 1.0 / (D**0.5)
logit_cap = 0.0
num_kv_splits = 8
enable_gqa = H_Q != H_KV
# q represents the new token being generated, one per batch
q = torch.randn(B, H_Q, D, dtype=dtype)
# k_buffer and v_buffer represent all previous tokens
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype)
v_buffer = k_buffer.narrow(2, 0, D_V)
key = torch.randn(B, H_KV, D, dtype=dtype)
value = key.narrow(2, 0, D_V)
# make sure no duplicates in loc
loc = torch.randperm(total_tokens)[:B].to(torch.int64)
k_buffer2 = k_buffer.clone()
v_buffer2 = k_buffer2.narrow(2, 0, D_V)
# o will have the same shape as q
o = torch.zeros(B, H_Q, D_V, dtype=dtype)
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype)
req_to_token = torch.arange(total_tokens).reshape(B, seq_len).to(torch.int32)
b_req_idx = torch.arange(B).to(torch.int64)
b_seq_len = torch.full((B,), seq_len).to(torch.int64)
attn_logits = torch.empty(
(B, H_Q, num_kv_splits, D_V + 1),
dtype=torch.float32,
)
torch.ops.sgl_kernel.decode_attention_cpu(
q,
k_buffer2,
v_buffer2,
o,
key,
value,
loc,
attn_logits,
req_to_token,
b_req_idx,
b_seq_len,
sm_scale,
logit_cap,
)
self._run_sdpa_forward_decode(
q,
o_grouped,
k_buffer,
v_buffer,
key,
loc,
req_to_token,
b_req_idx,
b_seq_len,
scaling=sm_scale,
enable_gqa=enable_gqa,
)
cos_sim = torch.nn.functional.cosine_similarity(
o.flatten(), o_grouped.flatten(), dim=0
)
atol = rtol = precision[q.dtype]
self.assertGreater(cos_sim.item(), 0.99)
torch.testing.assert_close(o, o_grouped, atol=atol, rtol=rtol)
torch.testing.assert_close(k_buffer, k_buffer2, atol=atol, rtol=rtol)
torch.testing.assert_close(v_buffer, v_buffer2, atol=atol, rtol=rtol)
def test_grouped_decode_attention(self):
configs = [
(1, 22, 1, 576, 512, 8 * 111),
(4, 22, 1, 576, 512, 8 * 128),
(40, 22, 1, 576, 512, 8 * 133),
]
for B, H_Q, H_KV, D, D_V, seqlen in configs:
self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V, seqlen)
if __name__ == "__main__":
unittest.main()
...@@ -33,7 +33,7 @@ def fused_moe(a, w1, w2, score, topk, renormalize, prepack): ...@@ -33,7 +33,7 @@ def fused_moe(a, w1, w2, score, topk, renormalize, prepack):
topk_weights = torch.empty(B, topk, dtype=torch.float32) topk_weights = torch.empty(B, topk, dtype=torch.float32)
topk_ids = torch.empty(B, topk, dtype=torch.int32) topk_ids = torch.empty(B, topk, dtype=torch.int32)
topk_weights, topk_ids = kernel.grouped_topk_cpu( topk_weights, topk_ids = kernel.grouped_topk_cpu(
a, score, topk, renormalize, G, topk_group a, score, topk, renormalize, G, topk_group, 0, None, None
) )
packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1 packed_w1 = kernel.convert_weight_packed(w1) if prepack else w1
......
...@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union ...@@ -4,7 +4,7 @@ from typing import Optional, Tuple, Union
import sgl_kernel import sgl_kernel
import torch import torch
from utils import precision from utils import make_non_contiguous, precision
from sglang.test.test_utils import CustomTestCase from sglang.test.test_utils import CustomTestCase
...@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase): ...@@ -38,6 +38,7 @@ class TestNorm(CustomTestCase):
def _norm_test(self, m, n, dtype): def _norm_test(self, m, n, dtype):
x = torch.randn([m, n], dtype=dtype) x = torch.randn([m, n], dtype=dtype)
x = make_non_contiguous(x)
hidden_size = x.size(-1) hidden_size = x.size(-1)
weight = torch.randn(hidden_size, dtype=dtype) weight = torch.randn(hidden_size, dtype=dtype)
variance_epsilon = 1e-6 variance_epsilon = 1e-6
...@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase): ...@@ -49,7 +50,7 @@ class TestNorm(CustomTestCase):
self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(ref_out, out, atol=atol, rtol=rtol))
ref_x = x.clone() ref_x = x.clone()
residual = torch.randn([m, n], dtype=dtype) residual = torch.randn([m, hidden_size], dtype=dtype)
ref_residual = residual.clone() ref_residual = residual.clone()
torch.ops.sgl_kernel.fused_add_rmsnorm_cpu( torch.ops.sgl_kernel.fused_add_rmsnorm_cpu(
......
...@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase ...@@ -14,6 +14,7 @@ from sglang.test.test_utils import CustomTestCase
convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed convert_weight_packed = torch.ops.sgl_kernel.convert_weight_packed
qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope qkv_proj_with_rope = torch.ops.sgl_kernel.qkv_proj_with_rope
qkv_proj_with_rope_fused_weight = torch.ops.sgl_kernel.qkv_proj_with_rope_fused_weight
torch.manual_seed(0) torch.manual_seed(0)
# constants # constants
kv_lora_rank = 512 kv_lora_rank = 512
...@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -148,6 +149,7 @@ class TestQKVProjWithROPE(CustomTestCase):
kv_a_proj_weight = ( kv_a_proj_weight = (
torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1 torch.randn(kv_lora_rank + qk_rope_head_dim, hidden_size, dtype=dtype) * 0.1
) )
fused_weight = torch.cat([q_a_proj_weight, kv_a_proj_weight], dim=0)
norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype) norm_weight2 = torch.randn(kv_lora_rank, dtype=dtype)
pos = torch.randint(10, 100, (B,)) pos = torch.randint(10, 100, (B,))
cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype) cos_sin_cache = torch.randn(100, rotary_dim, dtype=dtype)
...@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -167,6 +169,7 @@ class TestQKVProjWithROPE(CustomTestCase):
qb_packed = convert_weight_packed(q_b_proj_weight) qb_packed = convert_weight_packed(q_b_proj_weight)
kva_packed = convert_weight_packed(kv_a_proj_weight) kva_packed = convert_weight_packed(kv_a_proj_weight)
wkc_packed = convert_weight_packed(w_kc) wkc_packed = convert_weight_packed(w_kc)
fused_weight_packed = convert_weight_packed(fused_weight)
q_out, k_out, v_out = qkv_proj_with_rope( q_out, k_out, v_out = qkv_proj_with_rope(
hidden_states, hidden_states,
...@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -187,10 +190,33 @@ class TestQKVProjWithROPE(CustomTestCase):
True, True,
None, None,
) )
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
fused_weight_packed,
qb_packed,
wkc_packed,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
False,
False,
None,
None,
True,
None,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype] atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
def test_int8_qkv_proj_with_rope(self): def test_int8_qkv_proj_with_rope(self):
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -252,10 +278,36 @@ class TestQKVProjWithROPE(CustomTestCase):
True, True,
None, None,
) )
fused_weight = torch.cat([w1_q, w3_q], dim=0)
fused_weight_s = torch.cat([w1_s, w3_s], dim=0)
w_fused_q_packed = convert_weight_packed(fused_weight)
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
w_fused_q_packed,
w2_q_packed,
wkc_packed,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
True,
False,
fused_weight_s,
w2_s,
True,
None,
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype] atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol))
self.assertTrue(torch.allclose(fused_q_out, q_out))
self.assertTrue(torch.allclose(fused_k_out, k_out))
self.assertTrue(torch.allclose(fused_v_out, v_out))
def test_fp8_qkv_proj_with_rope(self): def test_fp8_qkv_proj_with_rope(self):
dtype = torch.bfloat16 dtype = torch.bfloat16
...@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -311,17 +363,17 @@ class TestQKVProjWithROPE(CustomTestCase):
pos, pos,
cos_sin_cache, cos_sin_cache,
) )
fp8_q_a_proj_weight = convert_weight_packed(fp8_q_a_proj_weight) fp8_q_a_proj_weight_packed = convert_weight_packed(fp8_q_a_proj_weight)
fp8_q_b_proj_weight = convert_weight_packed(fp8_q_b_proj_weight) fp8_q_b_proj_weight_packed = convert_weight_packed(fp8_q_b_proj_weight)
fp8_kv_a_proj_with_mqa_weight = convert_weight_packed( fp8_kv_a_proj_with_mqa_weight_packed = convert_weight_packed(
fp8_kv_a_proj_with_mqa_weight fp8_kv_a_proj_with_mqa_weight
) )
w_kc = convert_weight_packed(w_kc) w_kc = convert_weight_packed(w_kc)
q_out, k_out, v_out = qkv_proj_with_rope( q_out, k_out, v_out = qkv_proj_with_rope(
hidden_states, hidden_states,
fp8_q_a_proj_weight, fp8_q_a_proj_weight_packed,
fp8_q_b_proj_weight, fp8_q_b_proj_weight_packed,
fp8_kv_a_proj_with_mqa_weight, fp8_kv_a_proj_with_mqa_weight_packed,
w_kc, w_kc,
norm_weight1, norm_weight1,
norm_weight2, norm_weight2,
...@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase): ...@@ -336,10 +388,44 @@ class TestQKVProjWithROPE(CustomTestCase):
True, True,
[scale_block_size_N, scale_block_size_K], [scale_block_size_N, scale_block_size_K],
) )
fused_weight = torch.cat(
[fp8_q_a_proj_weight, fp8_kv_a_proj_with_mqa_weight], dim=0
)
fused_weight_s = torch.cat(
[q_a_proj_weight_scale_inv, kv_a_proj_with_mqa_weight_scale_inv], dim=0
)
fused_weight_packed = convert_weight_packed(fused_weight)
fused_q_out, fused_k_out, fused_v_out = qkv_proj_with_rope_fused_weight(
hidden_states,
fused_weight_packed,
fp8_q_b_proj_weight_packed,
w_kc,
norm_weight1,
norm_weight2,
pos,
cos_sin_cache,
eps,
False,
True,
fused_weight_s.float(),
q_b_proj_weight_scale_inv.float(),
True,
[scale_block_size_N, scale_block_size_K],
q_lora_rank,
kv_lora_rank,
qk_rope_head_dim,
)
atol = rtol = precision[q_ref.dtype] atol = rtol = precision[q_ref.dtype]
self.assertTrue(torch.allclose(q_ref, q_out, atol=atol, rtol=rtol)) # Due to the change in multiplication order, the error is amplified.
self.assertTrue(torch.allclose(k_ref, k_out, atol=atol, rtol=rtol)) # In the model, with fewer layers, this doesn't cause issues, but in
self.assertTrue(torch.allclose(v_ref, v_out, atol=atol, rtol=rtol)) # tests with more layers we need a larger tolerance to pass.
torch.testing.assert_close(q_ref, q_out, atol=1e-1, rtol=1e-1)
torch.testing.assert_close(k_ref, k_out, atol=atol, rtol=rtol)
torch.testing.assert_close(v_ref, v_out, atol=atol, rtol=rtol)
torch.testing.assert_close(fused_q_out, q_out)
torch.testing.assert_close(fused_k_out, k_out)
torch.testing.assert_close(fused_v_out, v_out)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase): ...@@ -34,7 +34,15 @@ class TestGroupedTopK(CustomTestCase):
# fused version # fused version
topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu( topk_weights, topk_ids = torch.ops.sgl_kernel.grouped_topk_cpu(
hidden_states, gating_output, topk, renormalize, G, topk_group hidden_states,
gating_output,
topk,
renormalize,
G,
topk_group,
0,
None,
None,
) )
res = torch.zeros(M, E, dtype=torch.float) res = torch.zeros(M, E, dtype=torch.float)
...@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase): ...@@ -83,6 +91,9 @@ class TestBiasedGroupedTopK(CustomTestCase):
renormalize, renormalize,
G, G,
topk_group, topk_group,
0,
None,
None,
) )
res = torch.zeros(M, E, dtype=torch.float) res = torch.zeros(M, E, dtype=torch.float)
......
...@@ -244,3 +244,11 @@ def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk): ...@@ -244,3 +244,11 @@ def native_fp8_fused_moe(a, w1, w2, topk_weight, topk_ids, topk):
.sum(dim=1) .sum(dim=1)
.to(a.dtype) .to(a.dtype)
) )
def make_non_contiguous(x: torch.Tensor) -> torch.Tensor:
"""
Make a tensor non-contiguous by slicing it along the last dimension.
"""
last_dim = x.shape[-1]
return x[..., : last_dim // 2] if x.is_contiguous() else x
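# Minimal usage sketch (hypothetical helper below, not referenced by the tests): the
# sliced view is non-contiguous overall but still contiguous in its last dimension,
# which is exactly what the relaxed CHECK_LAST_DIM_CONTIGUOUS_INPUT check in
# rmsnorm_cpu and fused_add_rmsnorm_cpu accepts.
def _non_contiguous_example() -> None:
    x = make_non_contiguous(torch.randn(4, 128))
    assert not x.is_contiguous()  # stride(0) is still 128, but only 64 columns are kept
    assert x.stride(-1) == 1  # last dimension remains contiguous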