Unverified Commit a73c4df4 authored by Ma Mingfei, committed by GitHub

Add optimized native kernels in sgl-kernel (#5150)


Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Co-authored-by: blzheng <beilei.zheng@intel.com>
parent 89a55418
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
int64_t num_tokens,
int64_t dim,
const func_t& f,
const vec_func_t& vf) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int64_t kVecSize = bVec::size();
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
scalar_t* __restrict__ output_ptr = output + i * dim;
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= dim - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input_other_ptr + d);
fVec y_fvec0, y_fvec1;
std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
x_fvec0 = vf(x_fvec0);
x_fvec1 = vf(x_fvec1);
x_fvec0 = x_fvec0 * y_fvec0;
x_fvec1 = x_fvec1 * y_fvec1;
x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
x_bvec.store(output_ptr + d);
}
#pragma GCC unroll 4
for (; d < dim; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float y_val = static_cast<float>(input_other_ptr[d]);
output_ptr[d] = f(x_val) * y_val;
}
}
});
}
} // anonymous namespace
// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
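// for reference, with x = one input row of length 2 * d, the kernel below computes
//   out[j] = silu(x[j]) * x[j + d],  silu(v) = v * sigmoid(v) = v / (1 + exp(-v))
// for j in [0, d): a SiLU gate on the first half multiplied with the second half.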
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[](float x) { return x / (1.f + std::exp(-x)); },
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
});
return out;
}
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
void bmm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
int64_t B,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideB,
int64_t mat1_strideM,
int64_t out_strideB,
int64_t out_strideM,
float scale = 0.f) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// mat2 contiguous in [B, N, K]
int64_t mat2_strideB = N * K;
int64_t mat2_strideN = K;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [B, MB, NB]
at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, mb{0}, nb{0};
data_index_init(begin, bs, B, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t>(
/* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
/* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
/* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
// move to the next index
data_index_step(bs, B, mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out : [B, M, N]
// scale: [] 0-dim tensor for per-tensor quantization
//
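// functionally this is a batched GEMM with a pre-transposed second operand,
// roughly out[b] = mat1[b] @ mat2[b].transpose(-1, -2) for each batch b;
// mat2 is stored as [B, N, K] and repacked to VNNI layout when is_vnni is false.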
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
// input and out could be non-contiguous
// weight needs to be contiguous in [OC, IC] order
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
CHECK_INPUT(mat2);
CHECK_DIM(3, out);
CHECK_DIM(3, mat1);
CHECK_DIM(3, mat2);
int64_t B = mat1.size(0);
int64_t M = mat1.size(1);
int64_t N = mat2.size(1);
int64_t K = mat1.size(2);
TORCH_CHECK(!scale.has_value(), "bmm: fp8 weight is not supported yet.");
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be a multiple of 32.");
int64_t mat1_strideB = mat1.stride(0);
int64_t mat1_strideM = mat1.stride(1);
int64_t out_strideB = out.stride(0);
int64_t out_strideM = out.stride(1);
// check shapes
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
bmm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
B,
M,
N,
K,
mat1_strideB,
mat1_strideM,
out_strideB,
out_strideM);
});
}
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace {
// dispatch bool
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
[&] { \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME = false; \
return __VA_ARGS__(); \
} \
}()
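// usage sketch (illustrative only; `do_work` and `scalar_t` are placeholders):
//   AT_DISPATCH_BOOL(use_brgemm, kUseBrgemm, [&] { do_work<scalar_t, kUseBrgemm>(); });
// both branches bind the runtime flag to a constexpr bool, so the callee can
// take it as a template parameter.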
// dispatch: bfloat16, float16, int8_t
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::BFloat16: { \
using packed_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using packed_t = at::Half; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Char: { \
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
}()
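// usage sketch (illustrative only; `weight` and `packed_gemm` are placeholders):
//   CPU_DISPATCH_PACKED_TYPES(weight.scalar_type(), [&] { packed_gemm<packed_t>(); });
// inside the lambda `packed_t` is bound to at::BFloat16, at::Half or int8_t.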
#define UNUSED(x) (void)(x)
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at last dimension")
#define CHECK_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CPU(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// parallel routines
constexpr int GRAIN_SIZE = 1024;
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
return (x + y - 1) / y;
}
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
// onednn partition pattern
T& n_my = n_end;
if (nth <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else {
T n1 = div_up(n, nth);
T n2 = n1 - 1;
T T1 = n - n2 * nth;
n_my = ith < T1 ? n1 : n2;
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
}
n_end += n_start;
#else
// pytorch aten partition pattern
T n_my = div_up(n, nth);
n_start = ith * n_my;
n_end = std::min(n_start + n_my, n);
#endif
}
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
#endif
}
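// usage sketch: distribute `n` work items over the OpenMP threads,
//   parallel_for(n, [&](int begin, int end) {
//     for (int i = begin; i < end; ++i) { /* process item i */ }
//   });
// without OpenMP this degenerates to a single serial call f(0, n).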
// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
return offset;
}
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
offset = data_index_init(offset, std::forward<Args>(args)...);
x = offset % X;
return offset / X;
}
inline bool data_index_step() {
return true;
}
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
if (data_index_step(std::forward<Args>(args)...)) {
x = ((x + 1) == X) ? 0 : (x + 1);
return x == 0;
}
return false;
}
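// usage sketch: collapse a nested [B, MB, NB] loop into one flat parallel range,
// as the kernels in this PR do:
//   int64_t b{0}, m{0}, n{0};
//   data_index_init(begin, b, B, m, MB, n, NB);
//   for (int64_t i = begin; i < end; ++i) {
//     // ... work on tile (b, m, n) ...
//     data_index_step(b, B, m, MB, n, NB);  // advance with carry, last index fastest
//   }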
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
template <int n>
struct Unroll {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
Unroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};
template <>
struct Unroll<1> {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
f(std::integral_constant<int, 0>{}, args...);
}
};
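// usage sketch: invoke `f` with compile-time indices 0..n-1, e.g.
//   Unroll<4>{}([&](auto i) { vc[i] = _mm512_setzero_ps(); });
// the index arrives as std::integral_constant<int, I>, so it remains usable in
// constexpr contexts inside the lambda (see the AVX512 kernels below).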
} // anonymous namespace
#include "common.h"
#include "vec.h"
namespace {
// [NOTE] TODO list for this kernel:
// 1. tune the value for BLOCK_N
// 2. planning for {batches, num_heads, num_kv_splits}
// and use actual num_kv_splits for small seq length
// 3. try fast impl of `.tanh()`
// 4. provide amx kernel for index_gemm_kernel_nn when M = 16
//
inline void fill_stub(float* __restrict__ out, float val, int64_t size) {
using Vec = at::vec::Vectorized<float>;
const Vec data_vec(val);
at::vec::map<float>([data_vec](Vec) { return data_vec; }, out, out, size);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec s_fvec = fVec(s);
int64_t d = 0;
for (; d <= size - bVec::size(); d += bVec::size()) {
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(acc[d] * s);
}
}
// GEMM handles query @ key (indexed) x scale
// A : [M, K]
// B : [N, K] indexed
// C : [M, N]
//
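// scalar reference for what this kernel computes (matches the generic path below):
//   C[m * ldc + n] = scale * sum_k A[m * lda + k] * B[indices[n] * ldb + k]
// i.e. rows of B are gathered through `indices` (token ids into the KV buffer).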
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
float scale,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t K,
int64_t max_tokens) {
for (int64_t m = 0; m < BLOCK_M; ++m) {
for (int64_t n = 0; n < BLOCK_N; ++n) {
float sum = 0.f;
int64_t b_idx = indices[n];
TORCH_CHECK(b_idx < max_tokens, "token index out of range!");
for (int64_t k = 0; k < K; ++k) {
sum += scale * static_cast<float>(A[m * lda + k]) * static_cast<float>(B[b_idx * ldb + k]);
}
C[m * ldc + n] = sum;
}
}
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
float scale,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t K,
int64_t max_tokens) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vscale = _mm512_set1_ps(scale);
auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
Unroll<ROWS * COLS>{}(loadc);
// for main loop
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_loadu_si512(A + row * lda + k));
}
if constexpr (row == 0) {
if constexpr (col + 1 < COLS) {
int64_t b_idx_prefetch = indices[col + 1];
_mm_prefetch(B + b_idx_prefetch * ldb + k, _MM_HINT_T0);
}
int64_t b_idx = indices[col];
TORCH_CHECK(b_idx < max_tokens, "token index out of range!");
vb[col] = (__m512bh)(_mm512_loadu_si512(B + b_idx * ldb + k));
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
// for remainder
auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_maskz_loadu_epi16(mask, A + row * lda + k));
}
if constexpr (row == 0) {
int64_t b_idx = indices[col];
TORCH_CHECK(b_idx < max_tokens, "token index out of range!");
vb[col] = (__m512bh)(_mm512_maskz_loadu_epi16(mask, B + b_idx * ldb + k));
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
int64_t k = 0;
for (; k <= K - 32; k += 32) {
Unroll<ROWS * COLS>{}(compute, k);
}
int64_t count = K - k;
if (count > 0) {
__mmask32 mask = (1ULL << count) - 1;
Unroll<ROWS * COLS>{}(compute2, k, mask);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);
// this scalar path is used when N is not a multiple of 16;
// N corresponds to `head_size_v`, which should normally be a multiple of 16
template <typename scalar_t, typename index_t>
inline void tinygemm_kernel_nn_scalar(
const float* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
const float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t max_tokens) {
for (int64_t m = 0; m < M; ++m) {
for (int64_t n = 0; n < N; ++n) {
C[m * ldc + n] *= scale[m];
for (int64_t k = 0; k < K; ++k) {
int64_t b_idx = indices[k];
TORCH_CHECK(b_idx < max_tokens, "token index out of range!");
C[m * ldc + n] += A[m * lda + k] * static_cast<float>(B[b_idx * ldb + n]);
}
}
}
}
// GEMM handles v' * scale + attn @ value (indexed)
// A : [M, K]
// B : [K, N] indexed
// C : [M, N]
//
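// scalar reference (matches tinygemm_kernel_nn_scalar above):
//   C[m * ldc + n] = C[m * ldc + n] * scale[m] + sum_k A[m * lda + k] * B[indices[k] * ldb + n]
// here `indices` gathers rows of B along the K (token) dimension.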
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const float* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
const float* __restrict__ scale,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t K,
int64_t max_tokens) {
tinygemm_kernel_nn_scalar(A, B, C, indices, scale, BLOCK_M, BLOCK_N, K, lda, ldb, ldc, max_tokens);
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
static inline void apply(
const float* __restrict__ A,
const at::BFloat16* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
const float* __restrict__ scale,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t K,
int64_t max_tokens) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
__m512 va;
__m512 vb[COLS];
__m512 vc[ROWS * COLS];
__m512 vscale;
auto loadc = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
if constexpr (col == 0) {
vscale = _mm512_set1_ps(scale[row]);
}
#pragma GCC diagnostic pop
vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
vc[i] = _mm512_mul_ps(vc[i], vscale);
};
Unroll<ROWS * COLS>{}(loadc);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_ps(A[row * lda + k]);
}
if constexpr (row == 0) {
if (k + 1 < K) {
int64_t b_idx_prefetch = indices[k + 1];
_mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
}
int64_t b_idx = indices[k];
TORCH_CHECK(b_idx < max_tokens, "token index out of range!");
// for COLS = 2, 4, 6, 8 use 512 bit load
// for COLS = 1, 3, 5, 7 use 256 bit load
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
__m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
vb[col + 0] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
vb[col + 1] = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
}
} else {
__m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
vb[col] = CVT_BF16_TO_FP32(b16);
}
}
vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
};
for (int64_t k = 0; k < K; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
_mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start, \
C + mb_start * ldc + nb_start, \
indices, \
scale + mb_start, \
lda, \
ldb, \
ldc, \
K, \
max_tokens);
template <typename scalar_t, typename index_t>
void index_gemm_kernel_nt(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
float scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t max_tokens) {
// pattern: 1-8-8
if (M == 1) {
constexpr int64_t BLOCK_N = 8;
const int64_t NB = div_up(N, BLOCK_N);
int64_t mb_start = 0, lda = 1, ldc = 1;
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (nb_size) {
case 1:
LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
break;
case 2:
LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
break;
case 3:
LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
break;
case 4:
LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
break;
case 5:
LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
break;
case 6:
LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
break;
case 7:
LAUNCH_TINYGEMM_KERNEL_NT(1, 7);
break;
case 8:
LAUNCH_TINYGEMM_KERNEL_NT(1, 8);
break;
default:
TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size");
}
}
return;
}
// pattern: 1-6-24
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 6;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size) {
// mb_size = 1
case 0x11:
LAUNCH_TINYGEMM_KERNEL_NT(1, 1);
break;
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NT(1, 2);
break;
case 0x13:
LAUNCH_TINYGEMM_KERNEL_NT(1, 3);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NT(1, 4);
break;
case 0x15:
LAUNCH_TINYGEMM_KERNEL_NT(1, 5);
break;
case 0x16:
LAUNCH_TINYGEMM_KERNEL_NT(1, 6);
break;
// mb_size = 2
case 0x21:
LAUNCH_TINYGEMM_KERNEL_NT(2, 1);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NT(2, 2);
break;
case 0x23:
LAUNCH_TINYGEMM_KERNEL_NT(2, 3);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NT(2, 4);
break;
case 0x25:
LAUNCH_TINYGEMM_KERNEL_NT(2, 5);
break;
case 0x26:
LAUNCH_TINYGEMM_KERNEL_NT(2, 6);
break;
// mb_size = 3
case 0x31:
LAUNCH_TINYGEMM_KERNEL_NT(3, 1);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NT(3, 2);
break;
case 0x33:
LAUNCH_TINYGEMM_KERNEL_NT(3, 3);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NT(3, 4);
break;
case 0x35:
LAUNCH_TINYGEMM_KERNEL_NT(3, 5);
break;
case 0x36:
LAUNCH_TINYGEMM_KERNEL_NT(3, 6);
break;
// mb_size = 4
case 0x41:
LAUNCH_TINYGEMM_KERNEL_NT(4, 1);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NT(4, 2);
break;
case 0x43:
LAUNCH_TINYGEMM_KERNEL_NT(4, 3);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NT(4, 4);
break;
case 0x45:
LAUNCH_TINYGEMM_KERNEL_NT(4, 5);
break;
case 0x46:
LAUNCH_TINYGEMM_KERNEL_NT(4, 6);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t, typename index_t>
void index_gemm_kernel_nn(
const float* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
const index_t* __restrict__ indices,
float* __restrict__ scale,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
int64_t max_tokens) {
constexpr int kVecSize = 16;
if ((N & (kVecSize - 1)) != 0) {
tinygemm_kernel_nn_scalar(A, B, C, indices, scale, M, N, K, lda, ldb, ldc, max_tokens);
return;
}
// pattern: 1-8-8
if (M == 1) {
constexpr int64_t BLOCK_N = 8 * kVecSize;
const int64_t NB = div_up(N, BLOCK_N);
int64_t mb_start = 0, lda = 1, ldc = 1;
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (nb_size >> 4) {
case 1:
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
break;
case 2:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 3:
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
break;
case 4:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
case 5:
LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
break;
case 6:
LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
break;
case 7:
LAUNCH_TINYGEMM_KERNEL_NN(1, 112);
break;
case 8:
LAUNCH_TINYGEMM_KERNEL_NN(1, 128);
break;
default:
TORCH_CHECK(false, "Unexpected block size, 1x", "nb_size");
}
}
return;
}
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 6 * kVecSize;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x11:
LAUNCH_TINYGEMM_KERNEL_NN(1, 16);
break;
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x13:
LAUNCH_TINYGEMM_KERNEL_NN(1, 48);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
case 0x15:
LAUNCH_TINYGEMM_KERNEL_NN(1, 80);
break;
case 0x16:
LAUNCH_TINYGEMM_KERNEL_NN(1, 96);
break;
// mb_size = 2
case 0x21:
LAUNCH_TINYGEMM_KERNEL_NN(2, 16);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x23:
LAUNCH_TINYGEMM_KERNEL_NN(2, 48);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
case 0x25:
LAUNCH_TINYGEMM_KERNEL_NN(2, 80);
break;
case 0x26:
LAUNCH_TINYGEMM_KERNEL_NN(2, 96);
break;
// mb_size = 3
case 0x31:
LAUNCH_TINYGEMM_KERNEL_NN(3, 16);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x33:
LAUNCH_TINYGEMM_KERNEL_NN(3, 48);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
case 0x35:
LAUNCH_TINYGEMM_KERNEL_NN(3, 80);
break;
case 0x36:
LAUNCH_TINYGEMM_KERNEL_NN(3, 96);
break;
// mb_size = 4
case 0x41:
LAUNCH_TINYGEMM_KERNEL_NN(4, 16);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x43:
LAUNCH_TINYGEMM_KERNEL_NN(4, 48);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
case 0x45:
LAUNCH_TINYGEMM_KERNEL_NN(4, 80);
break;
case 0x46:
LAUNCH_TINYGEMM_KERNEL_NN(4, 96);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t, typename index_t>
void decode_attention_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
int64_t batches,
int64_t num_heads,
int64_t head_size,
int64_t head_size_v,
int64_t num_kv_splits,
int64_t k_strideN,
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
using Vec = at::vec::Vectorized<float>;
// block length for k_buffer and v_buffer
constexpr int64_t BLOCK_N = 256;
// strides
const int64_t q_strideM = num_heads * head_size;
const int64_t q_strideH = head_size;
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
const int64_t l_stride2 = head_size_v + 1;
const bool has_logit_cap = logit_cap > 0;
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
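// the loop below implements the streaming (flash-decoding style) softmax over
// KV blocks; with block scores s = scaling * q @ K^T it maintains
//   m_new = max(m', max(s))
//   v'    = v' * exp(m' - m_new) + exp(s - m_new) @ V
//   s'    = s' * exp(m' - m_new) + sum(exp(s - m_new))
//   m'    = m_new
// and each kv split finally stores v' / s' plus the logsumexp m' + log(s')
// in the last slot of its attn_logits row.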
// parallel on [batches, num_heads, num_kv_splits]
at::parallel_for(0, batches * num_heads * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_id{0}, kv_id{0};
data_index_init(begin, bs, batches, head_id, num_heads, kv_id, num_kv_splits);
// scratch buffer shared by s_i and s_delta
alignas(64) float s_i[BLOCK_N];
float* __restrict__ s_delta = s_i;
for (int64_t i = begin; i < end; ++i) {
// get query
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + head_id * q_strideH;
// get key/value
int64_t seq_len_kv = seq_lens[bs];
int64_t req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of range!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of range!");
const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
const int64_t kv_start = kv_id * SPLIT_SIZE;
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
float m_prime = -std::numeric_limits<float>::infinity();
float s_prime = 0.f;
// get v_prime, and init to zero
float* __restrict__ v_prime = attn_logits + i * (head_size_v + 1);
fill_stub(v_prime, 0.f, head_size_v);
// loop over K and V sequence with BLOCK_N
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
int64_t n_size = std::min(BLOCK_N, kv_end - n);
// calculate s_i <- scale * Q @ K
index_gemm_kernel_nt<scalar_t, index_t>(
/* A */ q_ptr,
/* B */ k_buffer + head_id * k_strideH,
/* C */ s_i,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* scl */ scaling,
/* M */ 1,
/* N */ n_size,
/* K */ head_size,
/* lda */ 1,
/* ldb */ k_strideN,
/* ldc */ 1,
/* mtt */ max_total_num_tokens);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i,
s_i,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>([](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i, n_size);
m_i = std::max(m_i, m_prime);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>([m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta, s_i, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime *= m_delta;
s_prime += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta, n_size);
m_prime = m_i;
// calculate V' <- s_delta @ V + V' * m_delta
index_gemm_kernel_nn<scalar_t, index_t>(
/* A */ s_delta,
/* B */ v_buffer + head_id * v_strideH,
/* C */ v_prime,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* scl */ &m_delta,
/* M */ 1,
/* N */ head_size_v,
/* K */ n_size,
/* lda */ 1,
/* ldb */ v_strideN,
/* ldc */ 1,
/* mtt */ max_total_num_tokens);
} // loop with KV blocks
// only update v' when kv_split_size > 0
if (kv_end > kv_start) {
float s = 1 / s_prime;
at::vec::map<float>([s](Vec out) { return out * Vec(s); }, v_prime, v_prime, head_size_v);
v_prime[head_size_v] = m_prime + std::log(s_prime);
}
// move to the next index
data_index_step(bs, batches, head_id, num_heads, kv_id, num_kv_splits);
}
});
// parallel on [batches, num_heads]
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
// NB: here we use logits[b][h][0] as acc, since
// for the first kv split (kv_id == 0):
// m_delta = std::exp(-inf) = 0
// e_logic = std::exp(0) = 1
// acc = acc * m_delta + tv * e_logic = tv
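// for kv_id > 0 the splits are merged with the standard logsumexp rule:
//   m_new = max(m', tlogic)
//   acc   = acc * exp(m' - m_new) + tv * exp(tlogic - m_new)
//   s'    = s' * exp(m' - m_new) + exp(tlogic - m_new)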
for (int64_t i = begin; i < end; ++i) {
float* __restrict__ acc = attn_logits + i * l_stride1;
float s_prime = 0.f;
float m_prime = -std::numeric_limits<float>::infinity();
// update acc with the partial results from each kv_split
for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
float* __restrict__ tv = acc + kv_id * l_stride2;
const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
float m_i = std::max(tlogic, m_prime);
float m_delta = std::exp(m_prime - m_i);
float e_logic = std::exp(tlogic - m_i);
if (kv_id != 0) {
at::vec::map2<float>(
[m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
acc,
acc,
tv,
head_size_v);
}
s_prime = s_prime * m_delta + e_logic;
m_prime = m_i;
}
copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
}
});
}
template <typename scalar_t, typename index_t>
void decode_attention_grouped_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
int64_t batches,
int64_t num_heads,
int64_t num_heads_kv,
int64_t head_size,
int64_t head_size_v,
int64_t num_kv_splits,
int64_t k_strideN,
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
using Vec = at::vec::Vectorized<float>;
// block length for k_buffer and v_buffer
constexpr int64_t BLOCK_N = 256;
// block length for heads
// we parallel on [batches, divup(num_heads, BLOCK_H), num_kv_splits]
// use smaller BLOCK_H when batches is small to utilize all cores
constexpr int64_t kBLOCK_H = 16;
const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
// strides
const int64_t q_strideM = num_heads * head_size;
const int64_t q_strideH = head_size;
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
const int64_t l_stride2 = head_size_v + 1;
const bool has_logit_cap = logit_cap > 0;
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
// partition the heads into blocks for parallel
const int64_t num_groups = num_heads / num_heads_kv;
const int64_t num_blocks = div_up(num_heads, std::min(BLOCK_H, num_groups));
const int64_t num_groups_per_block = div_up(num_groups, BLOCK_H);
const int64_t num_heads_per_block = std::min(num_groups, BLOCK_H);
// parallel on [batches, num_blocks, num_kv_splits]
at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_id{0}, kv_id{0};
data_index_init(begin, bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
alignas(64) float s_i[BLOCK_H * BLOCK_N];
float* __restrict__ s_delta = s_i;
alignas(64) float s_prime[BLOCK_H];
alignas(64) float m_prime[BLOCK_H];
alignas(64) float m_delta[BLOCK_H];
for (int64_t i = begin; i < end; ++i) {
const int64_t h_start = head_id * num_heads_per_block;
const int64_t h_end = std::min(h_start + num_heads_per_block, num_heads);
const int64_t h_size = h_end - h_start;
// get query
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
// kv head id and valid block head size
int64_t head_kv_id = head_id / num_groups_per_block;
int64_t seq_len_kv = seq_lens[bs];
int64_t req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of range!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of range!");
const int64_t SPLIT_SIZE = div_up(seq_len_kv, num_kv_splits);
const int64_t kv_start = kv_id * SPLIT_SIZE;
const int64_t kv_end = std::min(kv_start + SPLIT_SIZE, seq_len_kv);
fill_stub(s_prime, 0.f, BLOCK_H);
fill_stub(m_prime, -std::numeric_limits<float>::infinity(), BLOCK_H);
// get v_prime, and init to zero
float* __restrict__ v_prime = attn_logits + bs * l_stride0 + h_start * l_stride1 + kv_id * l_stride2;
for (int64_t h = 0; h < h_size; ++h) {
fill_stub(v_prime + h * l_stride1, 0.f, head_size_v);
}
// loop over K and V sequence with BLOCK_N
for (int64_t n = kv_start; n < kv_end; n += BLOCK_N) {
int64_t n_size = std::min(BLOCK_N, kv_end - n);
// calculate Q @ K
index_gemm_kernel_nt<scalar_t, index_t>(
/* A */ q_ptr,
/* B */ k_buffer + head_kv_id * k_strideH,
/* C */ s_i,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* scl */ scaling,
/* M */ h_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideH,
/* ldb */ k_strideN,
/* ldc */ BLOCK_N,
/* mtt */ max_total_num_tokens);
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i,
s_i,
n_size);
}
// update the scaling coefficients
for (int64_t h = 0; h < h_size; ++h) {
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + h * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[h]);
// m_delta <- exp(m' - m_i)
m_delta[h] = std::exp(m_prime[h] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + h * BLOCK_N, s_i + h * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[h] *= m_delta[h];
s_prime[h] += at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + h * BLOCK_N, n_size);
m_prime[h] = m_i;
}
// calculate V' <- s_delta @ V + V' * m_delta
index_gemm_kernel_nn<scalar_t, index_t>(
/* A */ s_delta,
/* B */ v_buffer + head_kv_id * v_strideH,
/* C */ v_prime,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* scl */ m_delta,
/* M */ h_size,
/* N */ head_size_v,
/* K */ n_size,
/* lda */ BLOCK_N,
/* ldb */ v_strideN,
/* ldc */ l_stride1,
/* mtt */ max_total_num_tokens);
} // loop with KV blocks
// only update v' when kv_split_size > 0
if (kv_end > kv_start) {
for (int64_t h = 0; h < h_size; ++h) {
float s = 1 / s_prime[h];
at::vec::map<float>(
[s](Vec out) { return out * Vec(s); }, v_prime + h * l_stride1, v_prime + h * l_stride1, head_size_v);
(v_prime + h * l_stride1)[head_size_v] = m_prime[h] + std::log(s_prime[h]);
}
}
// move to the next index
data_index_step(bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
}
});
// parallel on [batches, num_heads]
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
// NB: same as above
for (int64_t i = begin; i < end; ++i) {
float* __restrict__ acc = attn_logits + i * l_stride1;
float s_prime = 0.f;
float m_prime = -std::numeric_limits<float>::infinity();
// update acc with the partial results from each kv_split
for (int64_t kv_id = 0; kv_id < num_kv_splits; ++kv_id) {
float* __restrict__ tv = acc + kv_id * l_stride2;
const float tlogic = (acc + kv_id * l_stride2)[head_size_v];
float m_i = std::max(tlogic, m_prime);
float m_delta = std::exp(m_prime - m_i);
float e_logic = std::exp(tlogic - m_i);
if (kv_id != 0) {
at::vec::map2<float>(
[m_delta, e_logic](Vec x, Vec y) { return x * Vec(m_delta) + y * Vec(e_logic); },
acc,
acc,
tv,
head_size_v);
}
s_prime = s_prime * m_delta + e_logic;
m_prime = m_i;
}
copy_stub<scalar_t>(output + i * head_size_v, acc, 1 / s_prime, head_size_v);
}
});
}
} // anonymous namespace
// query: [num_tokens, num_heads, head_size]
// output: [num_tokens, num_heads, head_size]
// k_buffer: [max_total_num_tokens, num_heads, head_size]
// v_buffer: [max_total_num_tokens, num_heads, head_size_v]
// attn_logits: [num_seqs, num_heads, num_kv_splits, head_size_v + 1]
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
//
void decode_attention_cpu(
at::Tensor& query,
at::Tensor& output,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& attn_logits,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
double sm_scale,
double logit_cap) {
RECORD_FUNCTION(
"sgl-kernel::decode_attention_cpu",
std::vector<c10::IValue>(
{query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
CHECK_INPUT(query);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
CHECK_DIM(3, query);
CHECK_DIM(3, k_buffer);
CHECK_DIM(3, v_buffer);
int64_t num_seqs = seq_lens.size(0);
int64_t max_num_reqs = req_to_token.size(0);
int64_t max_context_len = req_to_token.size(1);
int64_t max_total_num_tokens = k_buffer.size(0);
int64_t num_heads = query.size(1);
int64_t num_heads_kv = k_buffer.size(1);
int64_t head_size = query.size(2);
int64_t head_size_v = v_buffer.size(2);
int64_t num_kv_splits = attn_logits.size(2);
CHECK_EQ(attn_logits.size(0), num_seqs);
CHECK_EQ(attn_logits.size(1), num_heads);
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
// strides for k_buffer and v_buffer
int64_t k_strideN = k_buffer.stride(0);
int64_t k_strideH = k_buffer.stride(1);
int64_t v_strideN = v_buffer.stride(0);
int64_t v_strideH = v_buffer.stride(1);
// check index data types
const auto index_dtype = req_to_token.scalar_type();
TORCH_CHECK(
index_dtype == at::kInt || index_dtype == at::kLong,
"decode: expect req_to_token to be int32 or int64, got ",
index_dtype);
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "decode: expect req_lens to be int64, got ", seq_lens.scalar_type());
TORCH_CHECK(
req_pool_indices.scalar_type() == at::kLong,
"decode: expect req_pool_indices to be int64, got ",
req_pool_indices.scalar_type());
AT_DISPATCH_REDUCED_FLOATING_TYPES(query.scalar_type(), "decode_attention_kernel", [&] {
AT_DISPATCH_INDEX_TYPES(index_dtype, "decode_attention_indices", [&] {
if (num_heads == num_heads_kv) {
// MHA
decode_attention_kernel_impl<scalar_t, index_t>(
output.data_ptr<scalar_t>(),
attn_logits.data_ptr<float>(),
query.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
num_seqs,
num_heads,
head_size,
head_size_v,
num_kv_splits,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens);
} else {
// GQA/MQA/MLA
decode_attention_grouped_kernel_impl<scalar_t, index_t>(
output.data_ptr<scalar_t>(),
attn_logits.data_ptr<float>(),
query.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
num_seqs,
num_heads,
num_heads_kv,
head_size,
head_size_v,
num_kv_splits,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens);
}
});
});
}
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: extend attention for CPU
// 1. tune BLOCK_M and BLOCK_N
// 2. can handle non-contiguous k_extend and v_extend
// 3. computes attention for prefix and extend separately
// 4. TODO: vectorize `pack_vnni` and `pack_vnni2`
//
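// each (batch, head, extend block) below runs the same streaming softmax as the
// decode kernel, in two stages: stage 1 attends to the prefix tokens gathered
// from k_buffer/v_buffer via req_to_token, stage 2 attends to the extend tokens
// in k_extend/v_extend with a causal mask on the last key block of each tile.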
template <typename index_t>
inline index_t get_index(index_t* ind, int i) {
return (ind == nullptr) ? (index_t)i : ind[i];
}
// convert to vnni format
// from [N, K/2, 2] to [K/2, N, 2] for bfloat16 and float16
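// example with N = 2, K = 4 (entries are flat source indices of each row):
//   src, row-major [N, K]: n0 = {0, 1, 2, 3}, n1 = {4, 5, 6, 7}
//   dst, [K/2, N, 2]     : {0, 1, 4, 5,  2, 3, 6, 7}
// consecutive K-pairs of each (optionally `ind`-gathered) row are interleaved
// across rows so AMX/VNNI dot-product instructions can consume them directly.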
template <typename scalar_t, typename index_t>
void pack_vnni(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int N,
int K,
int ld_src,
int ld_dst) {
for (int n = 0; n < N; ++n) {
index_t index = get_index(ind, n);
for (int k = 0; k < K / 2; ++k) {
for (int d = 0; d < 2; ++d) {
dst[k * ld_dst * 2 + n * 2 + d] = src[index * ld_src + k * 2 + d];
}
}
}
}
// convert to vnni format
// from [K/2, 2, N] to [K/2, N, 2] for bfloat16 and float16
template <typename scalar_t, typename index_t>
void pack_vnni2(
scalar_t* __restrict__ dst,
const scalar_t* __restrict__ src,
const index_t* __restrict__ ind,
int K,
int N,
int ld_src,
int ld_dst) {
int k = 0;
for (; k < (K >> 1) * 2; k += 2) {
index_t index0 = get_index(ind, k + 0);
index_t index1 = get_index(ind, k + 1);
for (int n = 0; n < N; ++n) {
dst[(k >> 1) * ld_dst * 2 + n * 2 + 0] = src[index0 * ld_src + n];
dst[(k >> 1) * ld_dst * 2 + n * 2 + 1] = src[index1 * ld_src + n];
}
}
if (K % 2 != 0) {
index_t index = get_index(ind, K - 1);
for (int n = 0; n < N; ++n) {
dst[(K >> 1) * ld_dst * 2 + n * 2 + 0] = src[index * ld_src + n];
dst[(K >> 1) * ld_dst * 2 + n * 2 + 1] = 0;
}
k += 2;
}
// TODO: check whether we can skip this!
// const int padded_K = div_up(K, TILE_K) * TILE_K;
// for (; k < padded_K; ++k) {
// for (int n = 0; n < N; ++n) {
// dst[k * ld_dst + n] = static_cast<scalar_t>(0);
// }
// }
}
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, float val, int size) {
using Vec = at::vec::Vectorized<scalar_t>;
const Vec data_vec = Vec(static_cast<scalar_t>(val));
int d = 0;
for (; d <= size - Vec::size(); d += Vec::size()) {
data_vec.store(out + d);
}
if (size - d > 0) {
data_vec.store(out + d, size - d);
}
}
template <typename scalar_t, int BLOCK_N>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input) {
static_assert(BLOCK_N % 32 == 0);
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int COLS = BLOCK_N / 16;
auto store = [&](auto i) {
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
fVec a_fvec0 = fVec::loadu(input + col * 16);
fVec a_fvec1 = fVec::loadu(input + col * 16 + 16);
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + col * 16);
}
};
Unroll<COLS>{}(store);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, float s, int size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec s_fvec = fVec(s);
int d = 0;
for (; d <= size - bVec::size(); d += bVec::size()) {
fVec a_fvec0 = fVec::loadu(acc + d) * s_fvec;
fVec a_fvec1 = fVec::loadu(acc + d + fVec::size()) * s_fvec;
bVec out_bvec = convert_from_float_ext<scalar_t>(a_fvec0, a_fvec1);
out_bvec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(acc[d] * s);
}
}
template <typename scalar_t, typename index_t, int BLOCK_M, int BLOCK_N>
void extend_attention_kernel_impl(
scalar_t* __restrict__ o_extend,
const scalar_t* __restrict__ q_extend,
const scalar_t* __restrict__ k_extend,
const scalar_t* __restrict__ v_extend,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
const index_t* __restrict__ extend_seq_lens,
const index_t* __restrict__ extend_start_loc,
const void* __restrict__ buffer,
int batches,
int num_heads,
int num_heads_kv,
int head_size,
int head_size_v,
int ke_strideN,
int ke_strideH,
int ve_strideN,
int ve_strideH,
int k_strideN,
int k_strideH,
int v_strideN,
int v_strideH,
float scaling,
float logit_cap,
int max_num_reqs,
int max_context_len,
int max_total_num_tokens,
int max_len_extend,
int buffer_size_per_thread,
bool is_prefix_skipped) {
using Vec = at::vec::Vectorized<float>;
// strides
const int q_strideM = num_heads * head_size;
const int q_strideH = head_size;
const int o_strideM = num_heads * head_size_v;
const int o_strideH = head_size_v;
// we use same buffer for packed key and value
const int ldb_tmp = std::max(head_size, head_size_v);
const bool has_logit_cap = logit_cap > 0;
float rlogit_cap = has_logit_cap ? 1 / logit_cap : 0.f;
const int num_groups = num_heads / num_heads_kv;
TORCH_CHECK(num_groups * num_heads_kv == num_heads);
// number of blocks along M
int MB = div_up(max_len_extend, BLOCK_M);
// parallel on [batches, num_heads, BM]
at::parallel_for(0, batches * num_heads * MB, 0, [&](int begin, int end) {
int bs{0}, head_id{0}, mb{0};
data_index_init(begin, bs, batches, head_id, num_heads, mb, MB);
int tid = at::get_thread_num();
// s_i and s_delta: [BLOCK_M, BLOCK_N]
float* __restrict__ s_i = reinterpret_cast<float*>((char*)(buffer) + tid * buffer_size_per_thread);
float* __restrict__ s_delta = s_i;
// v_prime: [BLOCK_M, head_size_v]
float* __restrict__ v_prime = s_i + BLOCK_M * BLOCK_N;
// s_delta2: [BLOCK_M, BLOCK_N]; copy of s_delta in scalar_t
scalar_t* __restrict__ s_delta2 = reinterpret_cast<scalar_t*>(v_prime + BLOCK_N * head_size_v);
// Btmp: [BLOCK_N, max(head_size, head_size_v)]
scalar_t* __restrict__ Btmp = s_delta2 + BLOCK_M * BLOCK_N;
// init Btmp just once for each thread to prevent NaN
fill_stub(Btmp, 0.f, BLOCK_N * ldb_tmp);
alignas(64) float s_prime[BLOCK_M];
alignas(64) float m_prime[BLOCK_M];
for (int i = begin; i < end; ++i) {
// seq_len = prefix + extend
int head_kv_id = head_id / num_groups;
int seq_len = seq_lens[bs];
int seq_len_extend = extend_seq_lens[bs];
int seq_len_prefix = seq_len - seq_len_extend;
int seq_extend_start_loc = extend_start_loc[bs];
int req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_prefix >= 0, "prefix len < 0!");
TORCH_CHECK(seq_len <= max_context_len, "seq_len out of range!");
TORCH_CHECK(req_pool_id < max_num_reqs, "req_pool_id out of range!");
if (is_prefix_skipped) {
TORCH_CHECK(seq_len_prefix == 0, "extend attention: expect seq_len_prefix to be 0, got ", seq_len_prefix);
}
// offset and size in MB
int m = mb * BLOCK_N;
int m_size = std::min(BLOCK_M, seq_len_extend - m);
if (m_size <= 0) {
data_index_step(bs, batches, head_id, num_heads, mb, MB);
continue;
}
// get query
const scalar_t* __restrict__ q_ptr = q_extend + (seq_extend_start_loc + m) * q_strideM + head_id * q_strideH;
// init v', s' and m'
fill_stub(v_prime, 0.f, m_size * head_size_v);
fill_stub(s_prime, 0.f, m_size);
fill_stub(m_prime, -std::numeric_limits<float>::infinity(), m_size);
// stage 1: compute scores with prefix
for (int n = 0; n < seq_len_prefix; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, seq_len_prefix - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_buffer + head_kv_id * k_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ k_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_buffer + head_kv_id * v_strideH,
/* ind */ req_to_token + req_pool_id * max_context_len + n,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ v_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_prefix
// stage 2: compute the triangle part
int num_keys = std::min(seq_len_extend, m + BLOCK_M);
for (int n = 0; n < num_keys; n += BLOCK_N) {
int n_size = std::min(BLOCK_N, num_keys - n);
// `n_size` is K in 2nd gemm, pad to TILE_K;
const int padded_n_size = div_up(n_size, TILE_K) * TILE_K;
// get key and pack
pack_vnni<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ k_extend + (seq_extend_start_loc + n) * ke_strideN + head_kv_id * ke_strideH,
/* ind */ nullptr,
/* N */ n_size,
/* K */ head_size,
/* ld_src */ ke_strideN,
/* ld_dst */ BLOCK_N);
// calculate s_i <- Q @ K
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ head_size,
/* lda */ q_strideM,
/* ldb */ BLOCK_N,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ q_ptr,
/* B */ Btmp,
/* C */ s_i);
// apply causal mask
if (num_keys - n <= BLOCK_N) {
for (int row = 0; row < m_size; ++row) {
int last_col = m + row - n;
// fill [last_col + 1, n_size) to -inf
float* row_ptr = s_i + row * BLOCK_N;
fill_stub(row_ptr + last_col + 1, -std::numeric_limits<float>::infinity(), n_size - last_col - 1);
}
}
const Vec scale_vec = Vec(scaling);
for (int row = 0; row < m_size; ++row) {
// s_i <- s_i * scale
at::vec::map<float>(
[scale_vec](Vec x) { return x * scale_vec; }, s_i + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// TODO: `tanh` from torch uses sleef u10, going to be slow
if (has_logit_cap) {
at::vec::map<float>(
[logit_cap, rlogit_cap](Vec x) { return Vec(logit_cap) * (x * Vec(rlogit_cap)).tanh(); },
s_i + row * BLOCK_N,
s_i + row * BLOCK_N,
n_size);
}
// m_i: max value per row
float m_i = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); }, s_i + row * BLOCK_N, n_size);
m_i = std::max(m_i, m_prime[row]);
// m_delta <- exp(m' - m_i)
float m_delta = std::exp(m_prime[row] - m_i);
// s_delta <- exp(s_i - m_i)
at::vec::map<float>(
[m_i](Vec x) { return (x - Vec(m_i)).exp_u20(); }, s_delta + row * BLOCK_N, s_i + row * BLOCK_N, n_size);
// s' <- s' * m_delta + sum(s_delta)
s_prime[row] *= m_delta;
s_prime[row] +=
at::vec::reduce_all<float>([](Vec& x, Vec& y) { return x + y; }, s_delta + row * BLOCK_N, n_size);
m_prime[row] = m_i;
// v' <- v' * m_delta
at::vec::map<float>(
[m_delta](Vec x) { return x * Vec(m_delta); },
v_prime + row * head_size_v,
v_prime + row * head_size_v,
head_size_v);
// pad s_delta with 0 first and then convert to scalar_t
fill_stub(s_delta + row * BLOCK_N + n_size, 0.f, padded_n_size - n_size);
copy_stub<scalar_t, BLOCK_N>(s_delta2 + row * BLOCK_N, s_delta + row * BLOCK_N);
}
// get value and pack
pack_vnni2<scalar_t, index_t>(
/* dst */ Btmp,
/* src */ v_extend + (seq_extend_start_loc + n) * ve_strideN + head_kv_id * ve_strideH,
/* ind */ nullptr,
/* K */ n_size,
/* N */ head_size_v,
/* ld_src */ ve_strideN,
/* ld_dst */ head_size_v);
// calculate V' <- s_delta @ V + V'
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ head_size_v,
/* K */ padded_n_size, // n_size
/* lda */ BLOCK_N,
/* ldb */ head_size_v,
/* ldc */ head_size_v,
/* add_C */ true,
/* A */ s_delta2,
/* B */ Btmp,
/* C */ v_prime);
} // loop with seq_len_extend
scalar_t* __restrict__ out_ptr = o_extend + (seq_extend_start_loc + m) * o_strideM + head_id * o_strideH;
for (int row = 0; row < m_size; ++row) {
float s = 1 / s_prime[row];
copy_stub<scalar_t>(out_ptr + row * o_strideM, v_prime + row * head_size_v, s, head_size_v);
}
// move to the next index
data_index_step(bs, batches, head_id, num_heads, mb, MB);
}
at::native::cpublas::brgemm_release();
});
}
} // anonymous namespace
// q_extend, k_extend, v_extend, o_extend: contiguous tensors
// k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
//
// q_extend: [num_tokens, num_heads, head_size]
// k_extend: [num_extend_tokens, num_heads, head_size]
// v_extend: [num_extend_tokens, num_heads, head_size]
// o_extend: [num_tokens, num_heads, head_size]
// k_buffer: [max_total_num_tokens, num_heads, head_size]
// v_buffer: [max_total_num_tokens, num_heads, head_size]
// req_to_token: [max_num_reqs, max_context_len] int32 or int64
// req_pool_indices: [num_seqs] int64
// seq_lens: [num_seqs] int64
// extend_seq_lens: [num_seqs]
// extend_start_loc: [num_seqs]
//
void extend_attention_cpu(
at::Tensor& q_extend,
at::Tensor& k_extend,
at::Tensor& v_extend,
at::Tensor& o_extend,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
at::Tensor& extend_seq_lens,
at::Tensor& extend_start_loc,
int64_t max_len_extend,
double sm_scale,
double logit_cap) {
RECORD_FUNCTION(
"sgl-kernel::extend_attention_cpu",
std::vector<c10::IValue>(
{q_extend,
k_extend,
v_extend,
o_extend,
k_buffer,
v_buffer,
req_to_token,
req_pool_indices,
seq_lens,
extend_seq_lens,
extend_start_loc}));
CHECK_INPUT(q_extend);
CHECK_INPUT(o_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
int num_seqs = seq_lens.size(0);
int max_num_reqs = req_to_token.size(0);
int max_context_len = req_to_token.size(1);
int max_total_num_tokens = k_buffer.size(0);
int num_heads = q_extend.size(1);
int num_heads_kv = k_extend.size(1);
int head_size = q_extend.size(2);
int head_size_v = v_extend.size(2);
// strides for k_extend and v_extend
int ke_strideN = k_extend.stride(0);
int ke_strideH = k_extend.stride(1);
int ve_strideN = v_extend.stride(0);
int ve_strideH = v_extend.stride(1);
// strides for k_buffer and v_buffer
int k_strideN = k_buffer.stride(0);
int k_strideH = k_buffer.stride(1);
int v_strideN = v_buffer.stride(0);
int v_strideH = v_buffer.stride(1);
// check sizes
CHECK_EQ(req_pool_indices.size(0), num_seqs);
CHECK_EQ(extend_seq_lens.size(0), num_seqs);
CHECK_EQ(extend_start_loc.size(0), num_seqs);
CHECK_EQ(v_extend.size(1), num_heads_kv);
CHECK_EQ(k_buffer.size(1), v_buffer.size(1));
// MLA will skip prefix part
const bool is_prefix_skipped = k_buffer.size(1) != num_heads_kv;
// check index data types
const auto index_dtype = req_to_token.scalar_type();
TORCH_CHECK(
index_dtype == at::kInt || index_dtype == at::kLong,
"extend: expect req_to_token to be int32 or int64, got ",
index_dtype);
TORCH_CHECK(seq_lens.scalar_type() == at::kLong, "extend: expect req_lens to be int64, got ", seq_lens.scalar_type());
TORCH_CHECK(
req_pool_indices.scalar_type() == at::kLong,
"extend: expect req_pool_indices to be int64, got ",
req_pool_indices.scalar_type());
TORCH_CHECK(
extend_seq_lens.scalar_type() == index_dtype && extend_start_loc.scalar_type() == index_dtype,
"extend: expect extend_seq_lens and extend_start_loc to have same dtype as req_to_token.");
// head_size and head_size_v need to be multiples of 32 as we transpose with 512-bit loads
TORCH_CHECK(head_size % 32 == 0, "invalid head_size ", head_size);
TORCH_CHECK(head_size_v % 32 == 0, "invalid head_size_v ", head_size_v);
// block size for query seq length
constexpr int BLOCK_M = 32;
// block size for key/value seq length
constexpr int BLOCK_N = 32;
const int size_per_thread =
/* s_i */ BLOCK_M * BLOCK_N * sizeof(float) +
/* v_prime */ BLOCK_M * head_size_v * sizeof(float) +
/* s_delta */ BLOCK_M * BLOCK_N * sizeof(uint16_t) +
/* Btmp */ BLOCK_N * std::max(head_size, head_size_v) * sizeof(uint16_t);
int num_threads = at::get_num_threads();
auto buffer = at::empty({num_threads, size_per_thread}, q_extend.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(q_extend.scalar_type(), "extend_attention_kernel", [&] {
AT_DISPATCH_INDEX_TYPES(index_dtype, "extend_attention_indices", [&] {
extend_attention_kernel_impl<scalar_t, index_t, BLOCK_M, BLOCK_N>(
o_extend.data_ptr<scalar_t>(),
q_extend.data_ptr<scalar_t>(),
k_extend.data_ptr<scalar_t>(),
v_extend.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
extend_seq_lens.data_ptr<index_t>(),
extend_start_loc.data_ptr<index_t>(),
buffer.data_ptr(),
num_seqs,
num_heads,
num_heads_kv,
head_size,
head_size_v,
ke_strideN,
ke_strideH,
ve_strideN,
ve_strideH,
k_strideN,
k_strideH,
v_strideN,
v_strideH,
sm_scale,
logit_cap,
max_num_reqs,
max_context_len,
max_total_num_tokens,
max_len_extend,
size_per_thread,
is_prefix_skipped);
});
});
}
#include "gemm.h"
#include "common.h"
#include "vec.h"
namespace {
// packed layout:
// quants {N, K} int8_t
// comp {N} int32_t
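// the compensation row stored right after the quantized block is
//   comp[n] = 128 * sum_k quants[n][k]
// (accumulated below via dpbusd against a vector of 0x80); the int8 GEMM can
// then feed activations shifted by +128 as u8 and subtract comp afterwards to
// recover the signed s8*s8 result.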
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
__m512i vcomp[COLS];
for (int col = 0; col < COLS; ++col) {
vcomp[col] = _mm512_setzero_si512();
}
const int64_t offset = BLOCK_N * K;
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
for (int k = 0; k < K / 4; ++k) {
for (int col = 0; col < COLS; ++col) {
__m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
}
}
for (int col = 0; col < COLS; ++col) {
_mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
}
#else
TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
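// e.g. with VNNI_BLK = 2, element weight[n][k] ends up at
// packed[(k / 2) * N * 2 + n * 2 + (k % 2)]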
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
const int VNNI_BLK = 2;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
}
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
constexpr int BLOCK_N = block_size_n();
TORCH_CHECK(N == BLOCK_N);
const int VNNI_BLK = 4;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
s8s8_compensation<BLOCK_N>(packed, K);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b_ptr = reinterpret_cast<const float*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
// for COLS = 1, 3 use 256bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int64_t m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
if (brg) {
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
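      // dispatch key packs (mb_size, nb_size / 16) into one byte,
      // e.g. mb_size = 3, nb_size = 64 -> 0x34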
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
}
}
}
}
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
  // use the avx512-bf16 path only when a) M is small and b) dtype is bfloat16; otherwise use amx (brgemm)
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
/* C */ out + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const TYPE* __restrict__ B, \
TYPE* __restrict__ C, \
float* __restrict__ Ctmp, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor convert_weight_packed(at::Tensor& weight) {
// for 3d moe weights
// weight : [E, OC, IC]
// w1 : [E, 2N, K]
// w2 : [E, K, N]
CHECK_INPUT(weight);
const int64_t ndim = weight.ndimension();
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
const auto st = weight.scalar_type();
const int64_t E = ndim == 3 ? weight.size(0) : 1;
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
// we handle 2 TILE_N at a time.
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
constexpr int64_t BLOCK_N = block_size_n();
const int64_t NB = div_up(OC, BLOCK_N);
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
auto packed_weight = at::empty({}, weight.options());
const int64_t stride = OC * IC;
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
    // adjust the innermost dimension size
const int packed_row_size = get_row_size<packed_t>(IC);
auto sizes = weight.sizes().vec();
sizes[ndim - 1] = packed_row_size;
packed_weight.resize_(sizes);
const packed_t* w_data = weight.data_ptr<packed_t>();
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
// parallel on {E, NB}
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
int64_t e{0}, nb{0};
data_index_init(begin, e, E, nb, NB);
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t n = nb * BLOCK_N;
int64_t n_size = std::min(BLOCK_N, OC - n);
pack_vnni<packed_t>(
packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);
// move to the next index
data_index_step(e, E, nb, NB);
}
});
});
return packed_weight;
}
// mat1 : [M, K]
// mat2 : [N, K]
// bias : [N]
// out : [M, N]
//
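// typical usage (illustrative): pack once offline, then reuse the packed weight
//   at::Tensor packed = convert_weight_packed(weight);
//   at::Tensor y = weight_packed_linear(x, packed, bias, /*is_vnni=*/true);
// passing is_vnni=false packs on the fly inside the op instead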
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
auto out = at::empty({M, N}, mat1.options());
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
weight_packed_linear_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM);
});
return out;
}
#pragma once
#include <ATen/native/CPUBlas.h>
// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// block size for AMX gemm
constexpr int block_size_m() {
return 2 * TILE_M;
}
constexpr int block_size_n() {
return 2 * TILE_N;
}
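// i.e. BLOCK_M = BLOCK_N = 32: two AMX tiles per block in each dimension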
// threshold for using brgemm (Intel AMX)
template <typename T>
inline bool can_use_brgemm(int M);
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
return M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
return true;
}
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <>
inline bool can_use_brgemm<int8_t>(int M) {
return false;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
// adjust leading dimension size for K
template <typename T>
inline int64_t get_row_size(int64_t K) {
return K;
}
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
return K + sizeof(int32_t);
}
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// shared expert implementation for int8 w8a8
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 vd0;
__m512 vd1[COLS];
    // note: a 4x4 tiling would spill registers, but luckily we only use 4x2
__m512 vbias[COLS];
// [NOTE]: s8s8 igemm compensation in avx512-vnni
//
    // avx512-vnni has no s8s8, so we need to convert s8s8 to u8s8 with a compensation term:
//
// a * b = (a + 128) * b - 128 * b
// s s u s u s
//
// 1) 128 * b is pre-computed when packing B to vnni formats
// 2) a + 128 is fused when dynamically quantize A
//
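    // e.g. a = -3, b = 5: (-3 + 128) * 5 - 128 * 5 = 625 - 640 = -15 = a * b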
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr (col == 0) {
vd0 = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp per 2 vectors
// also load bias if any
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
if constexpr (has_bias) {
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
}
}
}
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
if constexpr (has_bias) {
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
} else {
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
}
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 4, \
C + mb_start * ldc + nb_start, \
As + mb_start, \
Bs + nb_start, \
Bcomp + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
}
}
}
}
template <typename scalar_t>
void int8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const uint8_t* __restrict__ mat1,
const int8_t* __restrict__ mat2,
const float* __restrict__ scales1,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
const bool use_brgemm = false;
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use int32_t for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * K,
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ out + mb_start * N + nb_start,
/* Ctmp*/ Ctmp,
/* As */ scales1 + mb_start,
/* Bs */ scales2 + nb_start,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ N,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const uint8_t* __restrict__ A, \
const int8_t* __restrict__ B, \
TYPE* __restrict__ C, \
int32_t* __restrict__ Ctmp, \
const float* __restrict__ As, \
const float* __restrict__ Bs, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
CHECK_DIM(2, A);
int64_t M = A.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
const auto st = A.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
auto As = at::empty({M}, A.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
float* __restrict__ As_data = As.data_ptr<float>();
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
});
return std::make_tuple(Aq, As);
}
// weight : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
// mat1 : [M, K]
// mat2 : [N, K]
// scales1 : [M]
// scales2 : [N]
// bias : [N]
// out : [M, N]
//
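// per output element (roughly):
//   out[m][n] = (acc_u8s8[m][n] - comp[n]) * scales1[m] * scales2[n] + bias[n]
// where comp[n] = 128 * sum_k mat2[n][k] is the packed compensation term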
at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales1,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales1);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales1.numel(), M);
CHECK_EQ(scales2.numel(), N);
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
TORCH_CHECK(
scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
"int8_scaled_mm: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<uint8_t>(),
packed_w.data_ptr<int8_t>(),
scales1.data_ptr<float>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
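// scratch layout: one flat byte buffer holding Aq [M, K] (uint8) followed by As [M] (float)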
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
int64_t lda = mat1.stride(0);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales2.numel(), N);
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
  TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A to have the same dtype as out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
const int64_t buffer_size = M * K + M * sizeof(float);
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
Aq_data,
packed_w.data_ptr<int8_t>(),
As_data,
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
#include <ATen/record_function.h>
#include <torch/extension.h>
#include "shm.h"
// Communication settings
static int world_rank = -1;
static int world_size = -1;
static bool is_initialized = false;
static bool all_ranks_local_p = false;
void initialize(int size, int rank) {
if (is_initialized) {
return;
}
  // Check whether all ranks are on the same physical machine.
  // If true, we will use an SHM-based low-latency allreduce
auto ls_string = std::getenv("LOCAL_SIZE");
int ls = 0;
if (ls_string != NULL) {
    ls = std::stoi(ls_string);
}
if (size >= 1 && size == ls) {
all_ranks_local_p = true;
}
world_size = size;
world_rank = rank;
is_initialized = true;
auto addr_string = std::getenv("MASTER_ADDR");
if (addr_string == NULL) {
addr_string = "";
}
auto port_string = std::getenv("MASTER_PORT");
if (port_string == NULL) {
port_string = "";
}
if (all_ranks_local_p) {
shm_initialize(size, rank, addr_string, port_string);
}
}
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
auto numel = data.numel();
int data_size = 0;
bool data_type_fallback = false;
switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}
if (data_type_fallback || !all_ranks_local_p) {
// Fallback to torch distributed allreduce
std::vector<torch::Tensor> tensors = {data};
process_group->allreduce(tensors)->wait();
} else {
all_reduce_outer_loop(data, numel, data_size);
}
return;
}
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
auto numel = data.numel();
int data_size = 0;
bool data_type_fallback = false;
switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}
if (dim < 0) {
dim += data.dim();
}
if (data_type_fallback || !all_ranks_local_p) {
    // Fallback to torch distributed allgather
std::vector<std::vector<torch::Tensor>> output_tensors(1);
auto world_size = process_group->getSize();
for (int i = 0; i < world_size; i++) {
output_tensors[0].push_back(torch::empty_like(data));
}
std::vector<torch::Tensor> input_tensors = {data};
process_group->allgather(output_tensors, input_tensors)->wait();
return torch::cat(output_tensors[0], dim).contiguous();
}
std::vector<int64_t> result_shape = data.sizes().vec();
result_shape[dim] *= world_size;
torch::Tensor result_tensor = torch::empty(result_shape, data.options());
return all_gather(result_tensor, data, dim, numel, data_size);
}
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: Fused MoE kernel with AMX
//
// This file contains implementations for
// * `moe_align_block_size`
// * `fused_moe`
//
// The functionality is identical to the triton kernel, except that:
//   * silu_and_mul is fused with gemm1, therefore this kernel
//     allocates 2 intermediate_caches instead of 3
//   * `offsets` is added to `moe_align_block_size`, which keeps track
//     of the starting offset for each M block. This keeps the output of
//     silu_and_mul in sorted order, so load_A for the 2nd gemm is
//     contiguous and we can load A directly from intermediate_cache1.
//
// TODO:
//   1. tune BLOCK_M and BLOCK_N (so that BLOCK_N * K fits in L2)
//   2. add prefetch for the load of A, which is an indexed access
// 3. abstract at::native::cpublas::brgemm with WoQ gemm (M = 1 & M != 1)
//
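// rough data flow (shapes assumed from the kernels below):
//   ic1 : [M * topk, N] silu_and_mul output, written in sorted token order
//   ic2 : [M, topk, K]  gemm2 output scaled by topk_weights, written in original order
//   out : [M, K]        sum of ic2 over the topk dimension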
template <typename scalar_t>
inline void fill_stub(scalar_t* __restrict__ out, scalar_t val, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
const Vec data_vec(val);
at::vec::map<scalar_t>([data_vec](Vec out) { return out = data_vec; }, out, out, size);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
Vec data = Vec::loadu(input + d);
data.store(out + d);
}
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec weight_vec = fVec(weight);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) * weight_vec;
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] * weight);
}
}
// acc from [topk, K] to [K]
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
if (topk == 1) {
// do copy for topk = 1
copy_stub(out, input, K);
} else {
// do sum for topk != 1
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= K - kVecSize; d += kVecSize) {
fVec sum_fvec0 = fVec(0.f);
fVec sum_fvec1 = fVec(0.f);
for (int t = 0; t < topk; ++t) {
bVec x_bvec = bVec::loadu(input + t * K + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec0 += x_fvec0;
sum_fvec1 += x_fvec1;
}
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
out_bvec.store(out + d);
}
for (; d < K; ++d) {
float sum_val = 0.f;
for (int t = 0; t < topk; ++t) {
sum_val += static_cast<float>(input[t * K + d]);
}
out[d] = static_cast<scalar_t>(sum_val);
}
}
}
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(
scalar_t* __restrict__ out,
const float* __restrict__ input,
const scalar_t* __restrict__ input2,
float scale,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
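// small illustrative example for moe_align_block_size (BLOCK_M = 32, E = 2, topk_ids = [0, 1, 1, 0]):
//   counts     = [2, 2], cumsums = [0, 32, 64], num_tokens_post_pad = 64
//   expert_ids = [0, 1]
//   sorted_ids = [0, 3, pad...(numel=4), 1, 2, pad...]
//   offsets    = [0, 2, 4]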
template <int BLOCK_M>
int moe_align_block_size(
int32_t* __restrict__ sorted_ids,
int32_t* __restrict__ expert_ids,
int32_t* __restrict__ topk_ids,
int32_t* __restrict__ total_cnts,
int32_t* __restrict__ cumsums,
int32_t* __restrict__ offsets,
int num_experts,
int numel,
int num_threads) {
#define T_INDEX(tt) (total_cnts + (tt) * num_experts)
// accumulate count of expert ids locally
at::parallel_for(0, numel, 0, [&](int begin, int end) {
int tid = at::get_thread_num();
int32_t* __restrict__ local_cnts = T_INDEX(tid + 1);
for (int i = begin; i < end; ++i) {
local_cnts[topk_ids[i]]++;
}
});
using iVec = at::vec::Vectorized<int32_t>;
for (int t = 0; t < num_threads; ++t) {
at::vec::map2<int32_t>(
[](iVec x, iVec y) { return x + y; }, T_INDEX(t + 1), T_INDEX(t + 1), T_INDEX(t), num_experts);
}
  // the last row holds the sum for each expert
int32_t* total_cnts_t_1 = T_INDEX(num_threads);
cumsums[0] = 0;
for (int e = 0; e < num_experts; ++e) {
// accumulate `num_tokens_post_pad`, also as the expert offset
cumsums[e + 1] = cumsums[e] + div_up(total_cnts_t_1[e], BLOCK_M) * BLOCK_M;
for (int k = cumsums[e]; k < cumsums[e + 1]; k += BLOCK_M) {
expert_ids[k / BLOCK_M] = e;
}
}
int num_tokens_post_pad = cumsums[num_experts];
at::parallel_for(0, numel, 0, [&](int begin, int end) {
int tid = at::get_thread_num();
// thread tid offsets in `total_cnts`
int32_t* __restrict__ offsets = T_INDEX(tid);
for (int i = begin; i < end; ++i) {
int32_t expert_id = topk_ids[i];
int32_t b_offset = cumsums[expert_id];
int32_t t_offset = offsets[expert_id];
sorted_ids[b_offset + t_offset] = i;
offsets[expert_id]++;
}
});
// debug: the offset for thread t_1 should be identical to t_2
int32_t* total_cnts_t_2 = T_INDEX(num_threads - 1);
for (int e = 0; e < num_experts; ++e) {
TORCH_CHECK(total_cnts_t_1[e] == total_cnts_t_2[e]);
}
// padding value for sorted_ids: numel
auto sorted_id_size = [=](const int32_t* sorted_ids_ptr) {
for (int d = 0; d < BLOCK_M; ++d) {
if (sorted_ids_ptr[d] == numel) {
return d;
}
}
return BLOCK_M;
};
  // offsets holds the starting offset for each valid M block
// shape : [num_token_blocks + 1]
offsets[0] = 0;
const int num_token_blocks = num_tokens_post_pad / BLOCK_M;
at::parallel_for(0, num_token_blocks, GRAIN_SIZE / BLOCK_M, [&](int begin, int end) {
for (int mb = begin; mb < end; ++mb) {
offsets[mb + 1] = sorted_id_size(sorted_ids + mb * BLOCK_M);
}
});
  // TODO: do we need to vectorize this?
for (int mb = 0; mb < num_token_blocks; ++mb) {
offsets[mb + 1] += offsets[mb];
}
// debug: the last value of offsets should be `numel`
TORCH_CHECK(offsets[num_token_blocks] == numel);
return num_tokens_post_pad;
}
// silu : shape leading dimension
// input0 [m_size, BLOCK_N] BLOCK_N
// input1 [m_size, BLOCK_N] BLOCK_N
// output [M * topk, N] N
template <typename scalar_t, int BLOCK_N>
inline void silu_and_mul(
scalar_t* __restrict__ output,
const float* __restrict__ input0, // x: x0, x1
const float* __restrict__ input1, // y: y0, y1
int64_t m_size,
int64_t N) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
// no remainder
for (int64_t m = 0; m < m_size; ++m) {
scalar_t* __restrict__ out = output + m * N;
const float* __restrict__ x = input0 + m * BLOCK_N;
const float* __restrict__ y = input1 + m * BLOCK_N;
for (int64_t d = 0; d < BLOCK_N; d += bVec::size()) {
fVec x0 = fVec::loadu(x + d);
fVec x1 = fVec::loadu(x + d + fVec::size());
fVec y0 = fVec::loadu(y + d);
fVec y1 = fVec::loadu(y + d + fVec::size());
// silu
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
// mul
x0 = x0 * y0;
x1 = x1 * y1;
// convert
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
}
}
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn2 {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B0,
const scalar_t* __restrict__ B1,
scalar_t* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
    TORCH_CHECK(false, "tinygemm_kernel_nn2: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn2<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B0,
const at::BFloat16* __restrict__ B1,
at::BFloat16* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb0[COLS];
__m512bh vb1[COLS];
__m512 vc0[ROWS * COLS];
__m512 vc1[ROWS * COLS];
auto loadc = [&](auto i) {
vc0[i] = _mm512_set1_ps(0.f);
vc1[i] = _mm512_set1_ps(0.f);
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b0_ptr = reinterpret_cast<const float*>(B0);
const float* b1_ptr = reinterpret_cast<const float*>(B1);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb0[col] = (__m512bh)(_mm512_loadu_si512(b0_ptr + k * ldb2 + col * 16));
vb1[col] = (__m512bh)(_mm512_loadu_si512(b1_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b0_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
_mm_prefetch(b1_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc0[i] = _mm512_dpbf16_ps(vc0[i], va, vb0[col]);
vc1[i] = _mm512_dpbf16_ps(vc1[i], va, vb1[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
using Vec = at::vec::Vectorized<float>;
const Vec one = Vec(1.f);
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
Vec x0 = vc0[row * COLS + col + 0];
Vec x1 = vc0[row * COLS + col + 1];
Vec y0 = vc1[row * COLS + col + 0];
Vec y1 = vc1[row * COLS + col + 1];
// silu
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
// mul
x0 = x0 * y0;
x1 = x1 * y1;
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B0 + nb_start * 2, B1 + nb_start * 2, C + mb_start * ldc + nb_start, K, lda, ldb, ldc);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B0,
const scalar_t* __restrict__ B1,
scalar_t* __restrict__ C,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
// pattern: 1-(2+2)-(8+8)
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 32;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
}
}
}
}
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B,
float* __restrict__ C,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) { vc[i] = _mm512_set1_ps(0.f); };
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b_ptr = reinterpret_cast<const float*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), vc[i]);
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN2(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, B + nb_start * 2, C + mb_start * ldc + nb_start, K, lda, ldb, ldc);
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
float* __restrict__ C,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
// pattern: 1-2-8
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 32;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN2(1, 32);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN2(2, 32);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN2(3, 32);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN2(4, 32);
break;
default:
          TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", nb_size);
}
}
}
}
template <typename scalar_t>
void fused_experts_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
scalar_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ packed_w1,
const scalar_t* __restrict__ packed_w2,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// strides for w1: [E, 2N, K]
  TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N);
const int64_t stride_e = 2 * N * K;
const int64_t stride_n = K;
  // here we only parallelize over half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
scalar_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const scalar_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, input + index * K, K);
}
if (use_brgemm) {
// 1.b gemm: C0 = A @ B0
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B0,
/* C */ C0);
// 1.c gemm: C1 = A @ B1
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
// 1.d silu and mul
const int64_t offset = offsets[mb];
silu_and_mul<scalar_t, BLOCK_N>(ic1 + offset * N + nb * BLOCK_N, C0, C1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [E, K, N] as [E, OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_e2 = OC * IC;
const int64_t stride_oc = IC;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A ptr from ic1 of [M * topk, N] in sorted order
// so as to avoid copy A to tmp buffer again
const scalar_t* __restrict__ A = ic1 + offsets[mb] * N;
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const scalar_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
// 2.a gemm: C = A @ B
if (use_brgemm) {
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B,
/* C */ C);
} else {
tinygemm_kernel(
/* A */ A,
/* B */ B,
/* C */ C,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m];
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
// stage 3: out = intermediate_cache2.sum(dim=1)
// from [M, topk, K] to [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
}
});
}
template <typename scalar_t>
void shared_expert_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
scalar_t* __restrict__ input,
const scalar_t* __restrict__ packed_w1,
const scalar_t* __restrict__ packed_w2,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 1: intermediate_cache1 = silu(hidden_states @ w1)
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
  TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not a multiple of ", BLOCK_N);
const int64_t stride_n = K;
  // here we only parallelize over half of 2N to fuse silu_and_mul with gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
float* __restrict__ C0 = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
float* __restrict__ C1 = C0 + BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// A shape [m_size, K]
const scalar_t* A = input + mb * BLOCK_M * K;
// B shape [K, n_size] in vnni format
const scalar_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
const scalar_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
if (use_brgemm) {
// 1.b gemm: C0 = A @ B0
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B0,
/* C */ C0);
// 1.c gemm: C1 = A @ B1
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B1,
/* C */ C1);
// 1.d silu and mul
silu_and_mul<scalar_t, BLOCK_N>(ic1 + mb * BLOCK_M * N + nb * BLOCK_N, C0, C1, m_size, N);
} else {
// fused 1.bcd: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
// stage 2: output = intermediate_cache1 @ w2
// w2 : [K, N] as [OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_oc = IC;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
bool is_brgemm_used = false;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
const bool use_brgemm = can_use_brgemm<scalar_t>(m_size);
is_brgemm_used = is_brgemm_used || use_brgemm;
// A shape [m_size, IC]
const scalar_t* __restrict__ A = ic1 + mb * BLOCK_M * N;
// B shape [IC, n_size] in vnni format
const scalar_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
// 2.a gemm: C = A @ B
if (use_brgemm) {
at::native::cpublas::brgemm(
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N,
/* add_C */ false,
/* A */ A,
/* B */ B,
/* C */ C);
} else {
tinygemm_kernel(
/* A */ A,
/* B */ B,
/* C */ C,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
}
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
if (is_brgemm_used) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
// hidden_states: [M, K]
// w1: [E, 2N, K]
// w2: [E, K, N]
// topk_weights: [M, topk]
// topk_ids: [M, topk] (int32_t)
//
at::Tensor fused_experts_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& topk_weights,
at::Tensor& topk_ids,
bool inplace,
bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni) {
RECORD_FUNCTION(
"sgl-kernel::fused_experts_cpu", std::vector<c10::IValue>({hidden_states, w1, w2, topk_weights, topk_ids}));
auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1);
auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2);
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const auto st = hidden_states.scalar_type();
CHECK_INPUT(hidden_states);
CHECK_INPUT(w1);
CHECK_INPUT(w2);
CHECK_EQ(topk_weights.sizes(), topk_ids.sizes());
CHECK_DIM(2, hidden_states);
CHECK_DIM(3, w1);
CHECK_DIM(3, w2);
CHECK_DIM(2, topk_weights);
CHECK_DIM(2, topk_ids);
CHECK_EQ(topk_ids.scalar_type(), at::kInt);
CHECK_EQ(topk_weights.scalar_type(), at::kFloat);
int64_t M = hidden_states.size(0);
int64_t K = hidden_states.size(1);
int64_t N = w1.size(1) / 2;
int64_t E = w1.size(0);
int64_t topk = topk_weights.size(1);
// we use int32_t compensation for int8 w8a8
int64_t packed_K = get_row_size(K, use_int8_w8a8);
int64_t packed_N = get_row_size(N, use_int8_w8a8);
// check weight shapes
CHECK_EQ(w2.size(0), E);
CHECK_EQ(w2.size(1), K);
CHECK_EQ(packed_w1.size(2), packed_K);
CHECK_EQ(packed_w2.size(2), packed_N);
if (use_int8_w8a8) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
}
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
// NB: worst case is that every expert's token count has a remainder of 1 modulo BLOCK_M,
// which adds (BLOCK_M - 1) padded slots per expert
// 1. sorted_ids : [M * topk + E * (BLOCK_M - 1)]
// 2. expert_ids : [max_num_blocks]
// 3. total_cnts : [T + 1, E]
// 4. cumsums : [E + 1]
// 5. offsets : [max_num_blocks + 1]
//
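// Worked example (illustrative numbers only, assuming block_size_m() == 32):
// with M = 4, topk = 2, E = 8 and T = num_threads = 2, the int32 buffer holds
//   max_num_tokens_padded = 4 * 2 + 8 * 31 = 256
//   max_num_blocks        = div_up(256, 32) = 8
//   total elements        = 256 + 8 + (2 + 1) * 8 + (8 + 1) + (8 + 1) = 306
// and the five sections above are carved out of this one contiguous allocation.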
int num_threads = at::get_num_threads();
int64_t max_num_tokens_padded = M * topk + E * (BLOCK_M - 1);
int64_t max_num_blocks = div_up(max_num_tokens_padded, BLOCK_M);
auto buffer = at::empty(
{max_num_tokens_padded + max_num_blocks + (num_threads + 1) * E + (E + 1) + (max_num_blocks + 1)},
topk_ids.options());
int32_t* __restrict__ sorted_ids = buffer.data_ptr<int32_t>();
int32_t* __restrict__ expert_ids = sorted_ids + max_num_tokens_padded;
int32_t* __restrict__ total_cnts = expert_ids + max_num_blocks;
int32_t* __restrict__ cumsums = total_cnts + (num_threads + 1) * E;
int32_t* __restrict__ offsets = cumsums + (E + 1);
// init sorted_ids with `numel` as the padding number
// init expert_ids with `num_experts`
int64_t numel = M * topk;
at::parallel_for(0, max_num_blocks, GRAIN_SIZE / BLOCK_M, [&](int64_t begin, int64_t end) {
int64_t m_start = begin * BLOCK_M;
int64_t m_size = std::min((end - begin) * BLOCK_M, max_num_tokens_padded - m_start);
fill_stub(sorted_ids + m_start, (int32_t)numel, m_size);
fill_stub(expert_ids + begin, (int32_t)E, end - begin);
});
// zero total_cnts and cumsums
at::parallel_for(0, (num_threads + 1) * E + (E + 1), GRAIN_SIZE, [&](int64_t begin, int64_t end) {
fill_stub(total_cnts + begin, 0, end - begin);
});
// align experts index
int64_t num_tokens_post_pad = moe_align_block_size<BLOCK_M>(
sorted_ids, expert_ids, topk_ids.data_ptr<int32_t>(), total_cnts, cumsums, offsets, E, numel, num_threads);
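// A hedged sketch of what the alignment produces (assumed semantics, mirroring
// the triton-style moe_align_block_size): flattened slots (token * topk + k) are
// grouped by expert and each expert's group is padded up to a multiple of
// BLOCK_M. E.g. with M = 2, topk = 2, BLOCK_M = 4 and topk_ids = [[0, 1], [1, 1]]:
//   sorted_ids = [0, 4, 4, 4, 1, 2, 3, 4]   (4 == numel is the padding value)
//   expert_ids = [0, 1]                     (one expert id per block)
//   offsets    = [0, 1, 4]                  (prefix sum of valid rows per block)
// so block 0 runs expert 0 on 1 row and block 1 runs expert 1 on 3 rows.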
// unlike the triton kernel, we fuse silu with gemm1, so we only need 2 intermediate_caches:
// 1. intermediate_cache1 : [M * topk, N]
// 2. intermediate_cache2 : [M * topk, K]
// 3. A_tmp : [T, BLOCK_M * K]
// 4. C_tmp : [T, 2 * BLOCK_M * BLOCK_N]
//
// for int8 w8a8:
// 5. Aq_tmp : [M, K] or [M * topk, N]
// 6. As_tmp : [M * topk]
//
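// For instance (illustrative only, assuming bfloat16 activations so the `* 2`
// factors below are sizeof(scalar_t), and BLOCK_M = 32, BLOCK_N = 32): with
// M = 4, topk = 2, N = 2048, K = 7168 and T = 2 threads, the bf16 path needs
//   ic1: 8 * 2048 * 2  +  ic2: 8 * 7168 * 2
//   + A_tmp: 2 * 32 * 7168 * 2  +  C_tmp: 2 * 2 * 32 * 32 * 4   bytes.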
int64_t buffer_size_nbytes = M * topk * N * 2 + M * topk * K * 2 +
num_threads * BLOCK_M * K * (use_int8_w8a8 ? 1 : 2) +
num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
if (use_int8_w8a8) {
buffer_size_nbytes += std::max(M * K, M * topk * N) + M * topk * sizeof(float);
}
auto buffer2 = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "fused_experts_kernel_impl", [&] {
scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer2.data_ptr<int8_t>()));
scalar_t* __restrict__ intermediate_cache2 = intermediate_cache1 + M * topk * N;
if (use_int8_w8a8) {
uint8_t* __restrict__ A_tmp = (uint8_t*)((void*)(intermediate_cache2 + M * topk * K));
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * topk * N)));
auto w1s = w1_scale.value();
auto w2s = w2_scale.value();
TORCH_CHECK(w1s.numel() == E * 2 * N);
TORCH_CHECK(w2s.numel() == E * K);
fused_experts_int8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache1,
intermediate_cache2,
A_tmp,
C_tmp,
Aq_tmp,
As_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<int8_t>(),
packed_w2.data_ptr<int8_t>(),
w1s.data_ptr<float>(),
w2s.data_ptr<float>(),
topk_weights.data_ptr<float>(),
sorted_ids,
expert_ids,
offsets,
M,
N,
K,
E,
topk,
num_tokens_post_pad);
} else {
scalar_t* __restrict__ A_tmp = intermediate_cache2 + M * topk * K;
float* __restrict__ C_tmp = (float*)((void*)(A_tmp + num_threads * BLOCK_M * K));
fused_experts_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache1,
intermediate_cache2,
A_tmp,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<scalar_t>(),
packed_w2.data_ptr<scalar_t>(),
topk_weights.data_ptr<float>(),
sorted_ids,
expert_ids,
offsets,
M,
N,
K,
E,
topk,
num_tokens_post_pad);
}
});
return out_hidden_states;
}
// shared expert kernel
//
// hidden_states: [M, K]
// w1: [2N, K]
// w2: [K, N]
// fused_experts_out: [M, K] (same shape as hidden_states)
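// In effect the function computes (a summary of the kernel implementations,
// with the top half of w1 written as w1[:N] and the bottom half as w1[N:]):
//   ic1 = silu(hidden_states @ w1[:N]^T) * (hidden_states @ w1[N:]^T)   // [M, N]
//   out = ic1 @ w2^T + routed_scaling_factor * fused_experts_out        // [M, K]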
at::Tensor shared_expert_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& fused_experts_out,
double routed_scaling_factor,
bool inplace,
bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::shared_expert_cpu", std::vector<c10::IValue>({hidden_states, w1, w2}));
auto packed_w1 = is_vnni ? w1 : convert_weight_packed(w1);
auto packed_w2 = is_vnni ? w2 : convert_weight_packed(w2);
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const auto st = hidden_states.scalar_type();
CHECK_INPUT(hidden_states);
CHECK_INPUT(fused_experts_out);
CHECK_INPUT(w1);
CHECK_INPUT(w2);
CHECK_DIM(2, hidden_states);
CHECK_DIM(2, w1);
CHECK_DIM(2, w2);
CHECK_EQ(hidden_states.sizes(), fused_experts_out.sizes());
CHECK_EQ(hidden_states.scalar_type(), st);
int64_t M = hidden_states.size(0);
int64_t K = hidden_states.size(1);
int64_t N = w1.size(0) / 2;
// we use int32_t compensation for int8 w8a8
int64_t packed_K = get_row_size(K, use_int8_w8a8);
int64_t packed_N = get_row_size(N, use_int8_w8a8);
// check weight shapes
CHECK_EQ(w2.size(0), K);
CHECK_EQ(packed_w1.size(1), packed_K);
CHECK_EQ(packed_w2.size(1), packed_N);
if (use_int8_w8a8) {
TORCH_CHECK(w1_scale.has_value(), "missing w1_scale for int8 w8a8.");
TORCH_CHECK(w2_scale.has_value(), "missing w2_scale for int8 w8a8.");
TORCH_CHECK(!a1_scale.has_value(), "static quantization for activation not supported.");
TORCH_CHECK(!a2_scale.has_value(), "static quantization for activation not supported.");
}
at::Tensor out_hidden_states = inplace ? hidden_states : at::empty_like(hidden_states);
// unlike the triton kernel, we fuse silu with gemm1, so we only need 1 intermediate_cache:
// 1. intermediate_cache1 : [M, N]
// 2. C_tmp : [T, 2 * BLOCK_M * BLOCK_N]
//
// for int8 w8a8:
// 3. Aq_tmp : [M, K] or [M, N]
// 4. As_tmp : [M]
//
int num_threads = at::get_num_threads();
int64_t buffer_size_nbytes = M * N * 2 + num_threads * 2 * BLOCK_M * BLOCK_N * sizeof(float);
if (use_int8_w8a8) {
buffer_size_nbytes += std::max(M * K, M * N) + M * sizeof(float);
}
auto buffer = at::empty({buffer_size_nbytes}, hidden_states.options().dtype(at::kChar));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "share_experts_kernel_impl", [&] {
scalar_t* __restrict__ intermediate_cache1 = (scalar_t*)((void*)(buffer.data_ptr<int8_t>()));
float* __restrict__ C_tmp = (float*)((void*)(intermediate_cache1 + M * N));
if (use_int8_w8a8) {
uint8_t* __restrict__ Aq_tmp = (uint8_t*)((void*)(C_tmp + num_threads * 2 * BLOCK_M * BLOCK_N));
float* __restrict__ As_tmp = (float*)((void*)(Aq_tmp + std::max(M * K, M * N)));
auto w1s = w1_scale.value();
auto w2s = w2_scale.value();
TORCH_CHECK(w1s.numel() == 2 * N);
TORCH_CHECK(w2s.numel() == K);
shared_expert_int8_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache1,
C_tmp,
Aq_tmp,
As_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<int8_t>(),
packed_w2.data_ptr<int8_t>(),
w1s.data_ptr<float>(),
w2s.data_ptr<float>(),
fused_experts_out.data_ptr<scalar_t>(),
routed_scaling_factor,
M,
N,
K);
} else {
shared_expert_kernel_impl<scalar_t>(
out_hidden_states.data_ptr<scalar_t>(),
intermediate_cache1,
C_tmp,
hidden_states.data_ptr<scalar_t>(),
packed_w1.data_ptr<scalar_t>(),
packed_w2.data_ptr<scalar_t>(),
fused_experts_out.data_ptr<scalar_t>(),
routed_scaling_factor,
M,
N,
K);
}
});
return out_hidden_states;
}
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t size) {
using Vec = at::vec::Vectorized<scalar_t>;
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += Vec::size()) {
Vec data = Vec::loadu(input + d);
data.store(out + d);
}
}
template <>
inline void copy_stub<uint8_t>(uint8_t* __restrict__ out, const uint8_t* __restrict__ input, int64_t size) {
// size might be 64x + 32
std::memcpy(out, input, size * sizeof(uint8_t));
}
template <typename scalar_t>
inline void copy_mul_stub(scalar_t* __restrict__ out, const float* __restrict__ input, float weight, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec weight_vec = fVec(weight);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) * weight_vec;
fVec data1 = fVec::loadu(input + d + fVec::size()) * weight_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] * weight);
}
}
// accumulate from [topk, K] to [K]
template <typename scalar_t>
inline void sum_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ input, int64_t topk, int64_t K) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
if (topk == 1) {
// do copy for topk = 1
copy_stub(out, input, K);
} else {
// do sum for topk != 1
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= K - kVecSize; d += kVecSize) {
fVec sum_fvec0 = fVec(0.f);
fVec sum_fvec1 = fVec(0.f);
for (int t = 0; t < topk; ++t) {
bVec x_bvec = bVec::loadu(input + t * K + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec0 += x_fvec0;
sum_fvec1 += x_fvec1;
}
bVec out_bvec = convert_from_float_ext<scalar_t>(sum_fvec0, sum_fvec1);
out_bvec.store(out + d);
}
for (; d < K; ++d) {
float sum_val = 0.f;
for (int t = 0; t < topk; ++t) {
sum_val += static_cast<float>(input[t * K + d]);
}
out[d] = static_cast<scalar_t>(sum_val);
}
}
}
// out = input + input2 * scale
template <typename scalar_t>
inline void add_mul_stub(
scalar_t* __restrict__ out,
const float* __restrict__ input,
const scalar_t* __restrict__ input2,
float scale,
int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
const fVec s_vec = fVec(scale);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec x0 = fVec::loadu(input + d);
fVec x1 = fVec::loadu(input + d + fVec::size());
bVec y_bvec = bVec::loadu(input2 + d);
fVec y0, y1;
std::tie(y0, y1) = at::vec::convert_to_float(y_bvec);
x0 = x0 + y0 * s_vec;
x1 = x1 + y1 * s_vec;
bVec out_vec = convert_from_float_ext<scalar_t>(x0, x1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + float(input2[d]) * scale);
}
}
/// gemm for w13
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0,
const int32_t* __restrict__ Bcomp1,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
at::BFloat16* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
const int32_t* __restrict__ Bcomp0,
const int32_t* __restrict__ Bcomp1,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512i va;
__m512i vb0[COLS];
__m512i vb1[COLS];
__m512i vc0[ROWS * COLS];
__m512i vc1[ROWS * COLS];
__m512i vcomp0[COLS];
__m512i vcomp1[COLS];
__m512 vas;
__m512 vbs0[COLS];
__m512 vbs1[COLS];
auto loadc = [&](auto i) {
vc0[i] = _mm512_set1_epi32(0);
vc1[i] = _mm512_set1_epi32(0);
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b0_ptr = reinterpret_cast<const int32_t*>(B0);
const int32_t* b1_ptr = reinterpret_cast<const int32_t*>(B1);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb0[col] = _mm512_loadu_si512(b0_ptr + k * ldb4 + col * 16);
vb1[col] = _mm512_loadu_si512(b1_ptr + k * ldb4 + col * 16);
}
vc0[i] = _mm512_dpbusd_epi32(vc0[i], va, vb0[col]);
vc1[i] = _mm512_dpbusd_epi32(vc1[i], va, vb1[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto scalec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr (col == 0) {
vas = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp
if constexpr (row == 0) {
vbs0[col] = _mm512_loadu_ps(Bs0 + col * 16);
vbs1[col] = _mm512_loadu_ps(Bs1 + col * 16);
vcomp0[col] = _mm512_loadu_si512(Bcomp0 + col * 16);
vcomp1[col] = _mm512_loadu_si512(Bcomp1 + col * 16);
}
__m512 c0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc0[i], vcomp0[col]));
__m512 c1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc1[i], vcomp1[col]));
vc0[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c0, vas), vbs0[col]));
vc1[i] = _mm512_castps_si512(_mm512_mul_ps(_mm512_mul_ps(c1, vas), vbs1[col]));
};
Unroll<ROWS * COLS>{}(scalec);
using Vec = at::vec::Vectorized<float>;
const Vec one = Vec(1.f);
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
Vec x0 = _mm512_castsi512_ps(vc0[row * COLS + col + 0]);
Vec x1 = _mm512_castsi512_ps(vc0[row * COLS + col + 1]);
Vec y0 = _mm512_castsi512_ps(vc1[row * COLS + col + 0]);
Vec y1 = _mm512_castsi512_ps(vc1[row * COLS + col + 1]);
// silu
x0 = x0 / (one + x0.neg().exp_u20());
x1 = x1 / (one + x1.neg().exp_u20());
// mul
x0 = x0 * y0;
x1 = x1 * y1;
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(__m512(x1), __m512(x0))));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI(MB_SIZE, NB_SIZE) \
tinygemm_kernel_vnni<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B0 + nb_start * 4, \
B1 + nb_start * 4, \
C + mb_start * ldc + nb_start, \
As + mb_start, \
Bs0 + nb_start, \
Bs1 + nb_start, \
Bcomp0 + nb_start, \
Bcomp1 + nb_start, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
const int32_t* Bcomp0 = reinterpret_cast<const int32_t*>(B0 + block_size_n() * K);
const int32_t* Bcomp1 = reinterpret_cast<const int32_t*>(B1 + block_size_n() * K);
// pattern: 1-(2+2)-(8+8)
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 32;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
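// The switch key packs (mb_size, nb_size / 16) into one byte: for example,
// mb_size = 3 with nb_size = 32 yields (3 << 4) | (32 >> 4) = 0x32.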
switch (mb_size << 4 | nb_size >> 4) {
case 0x12:
LAUNCH_TINYGEMM_KERNEL_VNNI(1, 32);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_VNNI(2, 32);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_VNNI(3, 32);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_VNNI(4, 32);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
/// gemm for w2
template <typename scalar_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2 {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
float* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_vnni2<at::BFloat16, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
float* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 vas;
__m512 vbs[COLS];
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr (col == 0) {
vas = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp, two vectors at a time
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vbs[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vbs[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
}
}
__m512 x = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[i], vcomp[col]));
x = _mm512_mul_ps(_mm512_mul_ps(x, vas), vbs[col]);
_mm512_storeu_ps(reinterpret_cast<__m512*>(C + row * ldc + col * 16), x);
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_VNNI2(MB_SIZE, NB_SIZE) \
tinygemm_kernel_vnni2<scalar_t, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 4, \
C + mb_start * ldc + nb_start, \
As + mb_start, \
Bs + nb_start, \
Bcomp + nb_start, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
float* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
case 0x12:
LAUNCH_TINYGEMM_KERNEL_VNNI2(1, 32);
break;
case 0x22:
LAUNCH_TINYGEMM_KERNEL_VNNI2(2, 32);
break;
case 0x32:
LAUNCH_TINYGEMM_KERNEL_VNNI2(3, 32);
break;
case 0x42:
LAUNCH_TINYGEMM_KERNEL_VNNI2(4, 32);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
} // anonymous namespace
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 0: quantize input to uint8, [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
}
});
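// (Hedged note: quantize_row_int8 is defined elsewhere; as used here it performs
//  per-token dynamic quantization, writing an 8-bit row into Aq_tmp and one float
//  scale per row into As_tmp. The exact rounding / zero-point scheme lives in
//  that helper, not in this file.)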
// stage 1: intermediate_cache1 = silu(hidden_states @ w1[:N]) * (hidden_states @ w1[N:])
const int64_t MB = div_up(num_tokens_post_pad, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// strides for w1: [E, 2N, K]
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
// K and N are packed for int8
const int64_t packed_K = get_row_size<int8_t>(K);
const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_e = 2 * N * packed_K;
const int64_t stride_n = packed_K;
// here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
uint8_t* __restrict__ A = A_tmp + tid * BLOCK_M * K;
alignas(64) float As[BLOCK_M];
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
// B shape [K, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B0 = packed_w1 + expert_id * stride_e + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + expert_id * stride_e + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + expert_id * 2 * N + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + expert_id * 2 * N + nb1 * BLOCK_N;
// 1.a load A
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
int64_t m_size = offsets[mb + 1] - offsets[mb];
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m] / topk;
copy_stub(A + m * K, Aq_tmp + index * K, K);
As[m] = As_tmp[index];
}
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
const int64_t offset = offsets[mb];
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + offset * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
// stage 1.5: quantize ic1 to uint8, [M * topk, N]
at::parallel_for(0, M * topk, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
}
});
// stage 2: intermediate_cache2 = intermediate_cache1 @ w2
// w2 : [E, K, N] as [E, OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_e2 = OC * packed_N;
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = offsets[mb + 1] - offsets[mb];
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// A points into ic1 (shape [M * topk, N]) in sorted order,
// so we avoid copying A into the tmp buffer again
const uint8_t* __restrict__ A = Aq_tmp + offsets[mb] * N;
const float* __restrict__ As = As_tmp + offsets[mb];
const int32_t* A_ids = sorted_ids + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
int32_t expert_id = expert_ids[mb];
const int8_t* __restrict__ B = packed_w2 + expert_id * stride_e2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + expert_id * K + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
// 2.b copy from C to ic2 in original order
// and also mul topk_weights in float32
for (int64_t m = 0; m < m_size; ++m) {
int32_t index = A_ids[m];
float weight = topk_weights[index];
copy_mul_stub(ic2 + index * K + nb * BLOCK_N, C + m * BLOCK_N, weight, n_size);
}
}
});
// stage 3: out = intermediate_cache2.sum(dim=1)
// from [M, topk, K] to [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
sum_stub(output + m * K, ic2 + m * topk * K, topk, K);
}
});
}
#define INSTANTIATE_MOE_INT8_TEMPLATE(TYPE) \
template void fused_experts_int8_kernel_impl<TYPE>( \
TYPE* __restrict__ output, \
TYPE* __restrict__ ic1, \
TYPE* __restrict__ ic2, \
uint8_t* __restrict__ A_tmp, \
float* __restrict__ C_tmp, \
uint8_t* __restrict__ Aq_tmp, \
float* __restrict__ As_tmp, \
const TYPE* __restrict__ input, \
const int8_t* __restrict__ packed_w1, \
const int8_t* __restrict__ packed_w2, \
const float* __restrict__ w1s, \
const float* __restrict__ w2s, \
const float* __restrict__ topk_weights, \
const int32_t* __restrict__ sorted_ids, \
const int32_t* __restrict__ expert_ids, \
const int32_t* __restrict__ offsets, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t E, \
int64_t topk, \
int64_t num_tokens_post_pad)
INSTANTIATE_MOE_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_MOE_INT8_TEMPLATE(at::Half);
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K) {
// handle 2 tiles per block
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
// stage 0: quantize input to uint8, [M, K]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_tmp + m * K, As_tmp[m], input + m * K, K);
}
});
// stage 1: intermediate_cache1 = silu(hidden_states @ w1[:N]) * (hidden_states @ w1[N:])
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
TORCH_CHECK(N % BLOCK_N == 0, "Fixme when N is not multiples of ", BLOCK_N);
// K and N are packed for int8
const int64_t packed_K = get_row_size<int8_t>(K);
const int64_t packed_N = get_row_size<int8_t>(N);
const int64_t stride_n = packed_K;
// here we only parallelize over half of 2N so that silu_and_mul can be fused with the gemm
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB;
int64_t nb = i % NB;
// nb0 from top half and nb1 from bottom half
int64_t nb0 = nb, nb1 = nb + NB;
int64_t n_size = std::min(N - nb0 * BLOCK_N, BLOCK_N);
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
// A shape [m_size, K]
const uint8_t* A = Aq_tmp + mb * BLOCK_M * K;
const float* As = As_tmp + mb * BLOCK_M;
// B shape [K, n_size] in vnni format
const int8_t* __restrict__ B0 = packed_w1 + nb0 * BLOCK_N * stride_n;
const int8_t* __restrict__ B1 = packed_w1 + nb1 * BLOCK_N * stride_n;
const float* __restrict__ Bs0 = w1s + nb0 * BLOCK_N;
const float* __restrict__ Bs1 = w1s + nb1 * BLOCK_N;
// fused 1.b: silu_and_mul(A @ B0, A @ B1)
tinygemm_kernel(
/* A */ A,
/* B0 */ B0,
/* B1 */ B1,
/* C */ ic1 + mb * BLOCK_M * N + nb * BLOCK_N,
/* As */ As,
/* Bs0 */ Bs0,
/* Bs1 */ Bs1,
/* M */ m_size,
/* N */ n_size,
/* K */ K,
/* lda */ K,
/* ldb */ n_size,
/* ldc */ N);
}
});
// stage 1.5: quantize ic1 to uint8, [M, N]
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_tmp + m * N, As_tmp[m], ic1 + m * N, N);
}
});
// stage 2: output = intermediate_cache1 @ w2 + routed_scaling_factor * fused_experts_out
// w2 : [K, N] as [OC, IC]
const int64_t OC = K; // rename K as OC
const int64_t IC = N; // rename N as IC
const int64_t MB2 = MB;
const int64_t NB2 = div_up(OC, BLOCK_N);
const int64_t stride_oc = packed_N;
// parallel on [MB2, NB2]
at::parallel_for(0, MB2 * NB2, 0, [&](int64_t begin, int64_t end) {
// get local pointers
int tid = at::get_thread_num();
// we won't be using C1 for gemm2
float* __restrict__ C = C_tmp + tid * 2 * BLOCK_M * BLOCK_N;
for (int64_t i = begin; i < end; ++i) {
int64_t mb = i / NB2;
int64_t nb = i % NB2;
int64_t m_size = std::min(M - mb * BLOCK_M, BLOCK_M);
int64_t n_size = std::min(OC - nb * BLOCK_N, BLOCK_N);
// A shape [m_size, IC]
const uint8_t* __restrict__ A = Aq_tmp + mb * BLOCK_M * N;
const float* __restrict__ As = As_tmp + mb * BLOCK_M;
// B shape [IC, n_size] in vnni format
const int8_t* __restrict__ B = packed_w2 + nb * BLOCK_N * stride_oc;
const float* __restrict__ Bs = w2s + nb * BLOCK_N;
// 2.a gemm: C = A @ B
tinygemm_kernel<scalar_t>(
/* A */ A,
/* B */ B,
/* C */ C,
/* As */ As,
/* Bs */ Bs,
/* M */ m_size,
/* N */ n_size,
/* K */ IC,
/* lda */ IC,
/* ldb */ n_size,
/* ldc */ BLOCK_N);
// 2.b copy from C to output and add fused_experts_out
scalar_t* __restrict__ out = output + mb * BLOCK_M * K + nb * BLOCK_N;
const scalar_t* __restrict__ fused_out = fused_experts_out + mb * BLOCK_M * K + nb * BLOCK_N;
for (int64_t m = 0; m < m_size; ++m) {
add_mul_stub(out + m * K, C + m * BLOCK_N, fused_out + m * K, routed_scaling_factor, n_size);
}
}
});
}
#define INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(TYPE) \
template void shared_expert_int8_kernel_impl<TYPE>( \
TYPE* __restrict__ output, \
TYPE* __restrict__ ic1, \
float* __restrict__ C_tmp, \
uint8_t* __restrict__ Aq_tmp, \
float* __restrict__ As_tmp, \
const TYPE* __restrict__ input, \
const int8_t* __restrict__ packed_w1, \
const int8_t* __restrict__ packed_w2, \
const float* __restrict__ w1s, \
const float* __restrict__ w2s, \
const TYPE* __restrict__ fused_experts_out, \
float routed_scaling_factor, \
int64_t M, \
int64_t N, \
int64_t K)
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::BFloat16);
INSTANTIATE_SHARED_EXPERT_INT8_TEMPLATE(at::Half);
#include "common.h"
#include "vec.h"
namespace {
// NB: avoid using `at::vec::map<>` on bfloat16 or half
template <typename scalar_t>
void rmsnorm_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ weight,
int64_t batch_size,
int64_t hidden_size,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
sum_val += x_val * x_val;
}
sum_val += vec_reduce_sum(sum_fvec);
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
const fVec scale_fvec = fVec(rsqrt_var);
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec w_bvec = bVec::loadu(weight + d);
fVec w_fvec0, w_fvec1;
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(out_ptr + d);
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float w_val = static_cast<float>(weight[d]);
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
}
}
});
}
template <typename scalar_t>
void fused_add_rmsnorm_kernel_impl(
scalar_t* __restrict__ input,
scalar_t* __restrict__ residual,
const scalar_t* __restrict__ weight,
float* __restrict__ buffer,
int64_t batch_size,
int64_t hidden_size,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
float* __restrict__ buffer_ptr = buffer + tid * hidden_size;
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ input_ptr = input + i * hidden_size;
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec r_bvec = bVec::loadu(residual_ptr + d);
fVec r_fvec0, r_fvec1;
std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
x_fvec0 += r_fvec0;
x_fvec1 += r_fvec1;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(residual_ptr + d);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
x_fvec0.store(buffer_ptr + d);
x_fvec1.store(buffer_ptr + d + fVec::size());
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float r_val = static_cast<float>(residual_ptr[d]);
x_val += r_val;
residual_ptr[d] = static_cast<scalar_t>(x_val);
sum_val += x_val * x_val;
buffer_ptr[d] = x_val;
}
sum_val += vec_reduce_sum(sum_fvec);
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
const fVec scale_fvec = fVec(rsqrt_var);
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
bVec w_bvec = bVec::loadu(weight + d);
fVec w_fvec0, w_fvec1;
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
x_bvec.store(input_ptr + d);
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
input_ptr[d] = x_val;
}
}
});
}
} // anonymous namespace
// input : {batch_size, hidden_size}
// weight: {hidden_size}
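// Per row x of `input`, the kernel below computes
//   output = x * weight / sqrt(mean(x * x) + eps)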
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
CHECK_INPUT(input);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
CHECK_DIM(1, weight);
CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
at::Tensor output = at::empty_like(input);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
rmsnorm_kernel_impl<scalar_t>(
output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
batch_size,
hidden_size,
eps);
});
return output;
}
// input : {batch_size, hidden_size}
// residual: {batch_size, hidden_size}
// weight : {hidden_size}
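// In-place semantics per row (matching the kernel above):
//   residual = input + residual                                   // written back to `residual`
//   input    = residual * weight / sqrt(mean(residual^2) + eps)   // written back to `input`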
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
CHECK_DIM(2, residual);
CHECK_DIM(1, weight);
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
// allocate temp buffer to store x in float32 per thread
// TODO: implement a singleton for context
int64_t num_threads = at::get_num_threads();
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
fused_add_rmsnorm_kernel_impl<scalar_t>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
buffer.data_ptr<float>(),
batch_size,
hidden_size,
eps);
});
}
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
// [NOTE]: Fused kernel for QKV projection with weight absorption and RoPE
//
// 1. `q_a_proj` and `kv_a_proj_with_mqa` fused into one gemm,
// otherwise we need to split IC for the 2nd gemm.
// 2. `q_a_layernorm` and `kv_a_layernorm` fused into one parallel loop.
// 3. k_input and v_input share the same storage; the torch API does this
//    in `set_kv_buffer`, so no additional memory movement is needed.
//
// [C0, C1] = A @ [B0, B1]
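// For the DeepSeek R1 shapes listed before `qkv_proj_with_rope` below, this
// replaces two separate gemms with one pass over A: A is [num_seqs, 7168],
// B0 is q_a_proj [1536, 7168], B1 is kv_a_proj_with_mqa [576, 7168], so
// C0 is [num_seqs, 1536] and C1 is [num_seqs, 576] (an illustrative reading;
// the kernel itself is shape-generic).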
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B0,
const scalar_t* __restrict__ B1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t K) {
// convert_weight_packed makes sure N0 and N1 are multiples of 32
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB0 = div_up(N0, BLOCK_N);
const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [MB, NB0 + NB1]
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = BLOCK_N;
const scalar_t* __restrict__ B = nb < NB0 ? B0 : B1;
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
int64_t ldc = nb < NB0 ? N0 : N1;
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
tinygemm_kernel<scalar_t>(
/* A */ A + mb_start * K,
/* B */ B + local_nb_start * K /* nb * BLOCK_N * K */,
/* C */ C + mb_start * ldc + local_nb_start,
/* Ctmp*/ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ ldc,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
// [C0, C1] = A @ [B0, B1]
template <typename scalar_t>
void segment_gemm_kernel_impl(
scalar_t* __restrict__ C0,
scalar_t* __restrict__ C1,
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B0,
const int8_t* __restrict__ B1,
const float* __restrict__ As,
const float* __restrict__ Bs0,
const float* __restrict__ Bs1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB0 = div_up(N0, BLOCK_N);
const int64_t NB1 = div_up(N1, BLOCK_N);
const int64_t NB = NB0 + NB1;
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
const bool use_brgemm = false;
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
// parallel on [MB, NB0 + NB1]
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = BLOCK_N;
const int8_t* __restrict__ B = nb < NB0 ? B0 : B1;
const float* __restrict__ Bs = nb < NB0 ? Bs0 : Bs1;
scalar_t* __restrict__ C = nb < NB0 ? C0 : C1;
int64_t ldc = nb < NB0 ? N0 : N1;
int64_t local_nb_start = nb < NB0 ? nb_start : nb_start - N0;
tinygemm_kernel<scalar_t>(
/* A */ A + mb_start * K,
/* B */ B + local_nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ C + mb_start * ldc + local_nb_start,
/* Ctmp*/ Ctmp,
/* As */ As + mb_start,
/* Bs */ Bs + local_nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ ldc,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
template <typename scalar_t>
inline float reduce(const scalar_t* __restrict__ x, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
fVec sum_fvec = fVec(float(0));
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += bVec::size()) {
bVec x_bvec = bVec::loadu(x + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
}
return vec_reduce_sum(sum_fvec);
}
// map2 from aten functional doesn't have fast bf16->fp32 conversion
template <typename scalar_t>
inline void map2(scalar_t* y, const scalar_t* x, const scalar_t* __restrict__ w, float scale, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
fVec scale_fvec = fVec(scale);
// no remainder
#pragma GCC unroll 4
for (int64_t d = 0; d < size; d += bVec::size()) {
bVec x_bvec = bVec::loadu(x + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec w_bvec = bVec::loadu(w + d);
fVec w_fvec0, w_fvec1;
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(y + d);
}
}
template <typename scalar_t>
void rms_norm_kernel_impl(
scalar_t* __restrict__ input0,
scalar_t* __restrict__ input1,
const scalar_t* __restrict__ weight0,
const scalar_t* __restrict__ weight1,
int64_t M,
int64_t N0,
int64_t N1,
int64_t stride1,
float eps = 1e-5) {
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
scalar_t* x0 = input0 + m * N0;
scalar_t* x1 = input1 + m * stride1;
float scale0 = reduce(x0, N0);
float scale1 = reduce(x1, N1);
scale0 = float(1) / std::sqrt(scale0 / N0 + eps);
scale1 = float(1) / std::sqrt(scale1 / N1 + eps);
map2(x0, x0, weight0, scale0, N0);
map2(x1, x1, weight1, scale1, N1);
}
});
}
template <typename scalar_t>
inline void rotary(const scalar_t* input, scalar_t* out, const scalar_t* cos, const scalar_t* sin, int64_t size) {
TORCH_CHECK(false, "rotary scalar path not implemented.");
}
#if defined(CPU_CAPABILITY_AVX512)
template <>
inline void rotary<at::BFloat16>(
const at::BFloat16* input, at::BFloat16* out, const at::BFloat16* cos, const at::BFloat16* sin, int64_t size) {
// permute indices
const __m512i idx1 = _mm512_set_epi32(30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, 8, 6, 4, 2, 0);
const __m512i idx2 = _mm512_set_epi32(31, 29, 27, 25, 23, 21, 19, 17, 15, 13, 11, 9, 7, 5, 3, 1);
const __m512i idy1 = _mm512_set_epi32(23, 7, 22, 6, 21, 5, 20, 4, 19, 3, 18, 2, 17, 1, 16, 0);
const __m512i idy2 = _mm512_set_epi32(31, 15, 30, 14, 29, 13, 28, 12, 27, 11, 26, 10, 25, 9, 24, 8);
// rotary dim is 64, just 2 iters
#pragma GCC unroll 2
for (int64_t d = 0; d < size; d += 32) {
int64_t d2 = d >> 1;
// load coefs
__m512 vcos = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(cos + d2)));
__m512 vsin = CVT_BF16_TO_FP32(_mm256_loadu_si256(reinterpret_cast<const __m256i*>(sin + d2)));
// load input
__m512i a16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(input + d));
__m512 a = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
__m512 b = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
// from [16, 2] to [2, 16]
__m512 in1 = _mm512_mask_permutex2var_ps(a, 0xffff, idx1, b);
__m512 in2 = _mm512_mask_permutex2var_ps(a, 0xffff, idx2, b);
// out1 = in1 * cos - in2 * sin;
// out2 = in2 * cos + in1 * sin
__m512 out1 = _mm512_sub_ps(_mm512_mul_ps(in1, vcos), _mm512_mul_ps(in2, vsin));
__m512 out2 = _mm512_add_ps(_mm512_mul_ps(in2, vcos), _mm512_mul_ps(in1, vsin));
// from [2, 16] to [16, 2]
a = _mm512_mask_permutex2var_ps(out1, 0xffff, idy1, out2);
b = _mm512_mask_permutex2var_ps(out1, 0xffff, idy2, out2);
_mm512_storeu_si512(reinterpret_cast<__m512i*>((out + d)), (__m512i)(_mm512_cvtne2ps_pbh(b, a)));
}
}
#endif
template <typename scalar_t>
void rotary_emb_kernel_impl(
scalar_t* q_pe_out,
scalar_t* k_pe_out,
const scalar_t* q_pe,
const scalar_t* k_pe,
const int64_t* pos,
const scalar_t* cos_sin,
int64_t num_seqs,
int64_t num_heads,
int64_t rotary_dim,
int64_t q_strideB,
int64_t q_strideH,
int64_t k_strideB,
int64_t oq_strideB,
int64_t oq_strideH,
int64_t ok_strideB) {
TORCH_CHECK(rotary_dim % 32 == 0, "expect rotary_dim to be a multiple of 32.");
const int64_t rotary_offset = rotary_dim / 2;
// parallel on [num_seqs, num_heads + 1]
// top [num_heads] handle q_pe and bottom [1] handle k_pe
at::parallel_for(0, num_seqs * (num_heads + 1), GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t seq{0}, head_id{0};
data_index_init(begin, seq, num_seqs, head_id, num_heads + 1);
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
// get cos and sin cache ptr
int64_t index = pos[seq];
const scalar_t* cos = cos_sin + index * rotary_dim;
const scalar_t* sin = cos + rotary_offset;
const scalar_t* input =
(head_id < num_heads) ? q_pe + seq * q_strideB + head_id * q_strideH : k_pe + seq * k_strideB;
scalar_t* out =
(head_id < num_heads) ? q_pe_out + seq * oq_strideB + head_id * oq_strideH : k_pe_out + seq * ok_strideB;
rotary<scalar_t>(input, out, cos, sin, rotary_dim);
// move to the next index
data_index_step(seq, num_seqs, head_id, num_heads + 1);
}
});
}
} // anonymous namespace
extern at::Tensor
weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
extern at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
extern void
bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
// NB: shapes in DeepSeek R1
//
// hidden_states : [num_seqs, hidden_size] [1, 7168]
// q_a_proj_weight : [q_lora_rank, hidden_size] [1536, 7168]
// q_b_proj_weight : [num_heads * qk_head_dim, q_lora_rank] [4224, 1536]
// kv_a_proj_weight : [kv_lora_rank + qk_rope_head_dim, hidden_size] [576, 7168]
// w_kc : [num_heads, kv_lora_rank, qk_nope_head_dim] [22, 512, 128]
// q_a_layernorm_weight : [q_lora_rank] [1536]
// kv_a_layernorm_weight : [kv_lora_rank] [512]
//
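// Returned tensors (see the allocations inside the function body):
//   q_input : [num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim]
//   k_input : [num_seqs, 1, kv_lora_rank + qk_rope_head_dim]
//   v_input : a narrow()-ed view of k_input covering the first kv_lora_rank elements
//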
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& hidden_states,
at::Tensor& q_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& kv_a_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
std::optional<at::Tensor>& q_a_proj_scale,
std::optional<at::Tensor>& q_b_proj_scale,
std::optional<at::Tensor>& kv_a_proj_scale,
bool is_vnni) {
RECORD_FUNCTION(
"sgl-kernel::qkv_proj_with_rope",
std::vector<c10::IValue>({hidden_states, q_a_proj_weight, q_b_proj_weight, kv_a_proj_weight, w_kc}));
const auto st = hidden_states.scalar_type();
CHECK_INPUT(hidden_states);
CHECK_INPUT(positions);
CHECK_INPUT(cos_sin_cache);
CHECK_EQ(q_a_layernorm_weight.scalar_type(), st);
CHECK_EQ(kv_a_layernorm_weight.scalar_type(), st);
CHECK_EQ(positions.scalar_type(), at::kLong);
CHECK_EQ(cos_sin_cache.scalar_type(), st);
CHECK_DIM(2, hidden_states);
CHECK_DIM(3, w_kc);
CHECK_DIM(1, q_a_layernorm_weight);
CHECK_DIM(1, kv_a_layernorm_weight);
CHECK_DIM(1, positions);
CHECK_DIM(2, cos_sin_cache);
// skip contiguous checks for weights, expect prepacked
TORCH_CHECK(is_vnni, "qkv_proj_with_rope: expect weights are prepacked!");
int64_t num_seqs = hidden_states.size(0);
int64_t hidden_size = hidden_states.size(1);
int64_t q_lora_rank = q_a_proj_weight.size(0);
int64_t num_heads = w_kc.size(0);
int64_t kv_lora_rank = w_kc.size(1);
int64_t qk_head_dim = q_b_proj_weight.size(0) / num_heads;
int64_t qk_nope_head_dim = w_kc.size(2);
int64_t qk_rope_head_dim = kv_a_proj_weight.size(0) - kv_lora_rank;
int64_t rotary_dim = cos_sin_cache.size(1);
CHECK_EQ(positions.numel(), num_seqs);
CHECK_EQ(rotary_dim, qk_rope_head_dim);
CHECK_EQ(q_a_layernorm_weight.numel(), q_lora_rank);
CHECK_EQ(kv_a_layernorm_weight.numel(), kv_lora_rank);
// check the packed dimension
CHECK_EQ(q_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
CHECK_EQ(q_b_proj_weight.size(1), get_row_size(q_lora_rank, use_int8_w8a8));
CHECK_EQ(kv_a_proj_weight.size(1), get_row_size(hidden_size, use_int8_w8a8));
if (use_int8_w8a8) {
TORCH_CHECK(q_a_proj_scale.has_value(), "missing q_a_proj_scale for int8 w8a8.");
TORCH_CHECK(q_b_proj_scale.has_value(), "missing q_b_proj_scale for int8 w8a8.");
TORCH_CHECK(kv_a_proj_scale.has_value(), "missing kv_a_proj_scale for int8 w8a8.");
}
// outputs and temp buffer
const auto options = hidden_states.options();
auto q_input = at::empty({num_seqs, num_heads, kv_lora_rank + qk_rope_head_dim}, options);
auto k_input = at::empty({num_seqs, 1, kv_lora_rank + qk_rope_head_dim}, options);
auto v_input = k_input.narrow(-1, 0, kv_lora_rank);
// outputs of q_a_proj and q_b_proj
auto qa = at::empty({num_seqs, q_lora_rank}, options);
// stage 1: q_a_proj and kv_a_proj
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "qkv_proj_kernel_impl", [&] {
if (use_int8_w8a8) {
auto q_a_proj_s = q_a_proj_scale.value();
auto kv_a_proj_s = kv_a_proj_scale.value();
TORCH_CHECK(q_a_proj_s.numel() == q_lora_rank);
TORCH_CHECK(kv_a_proj_s.numel() == kv_lora_rank + qk_rope_head_dim);
auto buffer = at::empty({num_seqs * hidden_size + num_seqs * 4}, options.dtype(at::kByte));
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
float* __restrict__ As_data = (float*)((void*)(Aq_data + num_seqs * hidden_size));
const scalar_t* __restrict__ A_data = hidden_states.data_ptr<scalar_t>();
at::parallel_for(0, num_seqs, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * hidden_size, As_data[m], A_data + m * hidden_size, hidden_size);
}
});
segment_gemm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(),
k_input.data_ptr<scalar_t>(),
Aq_data,
q_a_proj_weight.data_ptr<int8_t>(),
kv_a_proj_weight.data_ptr<int8_t>(),
As_data,
q_a_proj_s.data_ptr<float>(),
kv_a_proj_s.data_ptr<float>(),
num_seqs,
q_lora_rank,
kv_lora_rank + qk_rope_head_dim,
hidden_size);
} else {
segment_gemm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(),
k_input.data_ptr<scalar_t>(),
hidden_states.data_ptr<scalar_t>(),
q_a_proj_weight.data_ptr<scalar_t>(),
kv_a_proj_weight.data_ptr<scalar_t>(),
num_seqs,
q_lora_rank,
kv_lora_rank + qk_rope_head_dim,
hidden_size);
}
});
// stage 2: apply rmsnorm inplace
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rms_norm_kernel_impl", [&] {
rms_norm_kernel_impl<scalar_t>(
qa.data_ptr<scalar_t>(),
v_input.data_ptr<scalar_t>(),
q_a_layernorm_weight.data_ptr<scalar_t>(),
kv_a_layernorm_weight.data_ptr<scalar_t>(),
num_seqs,
q_lora_rank,
kv_lora_rank,
kv_lora_rank + qk_rope_head_dim,
eps);
});
// stage 3: q_b_proj
at::Tensor qb;
std::optional<at::Tensor> bias;
if (use_int8_w8a8) {
qb = int8_scaled_mm_with_quant(qa, q_b_proj_weight, q_b_proj_scale.value(), bias, at::kBFloat16, is_vnni);
} else {
qb = weight_packed_linear(qa, q_b_proj_weight, bias, is_vnni);
}
qb.as_strided_({num_seqs, num_heads, qk_head_dim}, {num_heads * qk_head_dim, qk_head_dim, 1});
// stage 4: bmm
std::optional<at::Tensor> scale;
auto q_nope = qb.narrow(2, 0, qk_nope_head_dim).transpose_(0, 1);
auto q_nope_out = q_input.narrow(2, 0, kv_lora_rank).transpose_(0, 1);
bmm_cpu(q_nope_out, q_nope, w_kc, is_vnni, scale);
// stage 5: rope
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "rotary_emb_kernel_impl", [&] {
rotary_emb_kernel_impl<scalar_t>(
q_input.data_ptr<scalar_t>() + kv_lora_rank,
k_input.data_ptr<scalar_t>() + kv_lora_rank,
qb.data_ptr<scalar_t>() + qk_nope_head_dim,
k_input.data_ptr<scalar_t>() + kv_lora_rank,
positions.data_ptr<int64_t>(),
cos_sin_cache.data_ptr<scalar_t>(),
num_seqs,
num_heads,
rotary_dim,
num_heads * qk_head_dim,
qk_head_dim,
kv_lora_rank + qk_rope_head_dim,
num_heads * (kv_lora_rank + qk_rope_head_dim),
kv_lora_rank + qk_rope_head_dim,
kv_lora_rank + qk_rope_head_dim);
});
return std::make_tuple(q_input, k_input, v_input);
}
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t>
void rope_kernel_impl(
scalar_t* __restrict__ q_pe_out,
scalar_t* __restrict__ k_pe_out,
int64_t* __restrict__ t_pos,
scalar_t* __restrict__ q_pe,
scalar_t* __restrict__ k_pe,
scalar_t* __restrict__ t_emb_pos,
int64_t seq_len,
int64_t num_head,
int64_t rotary_dim,
int64_t HR,
int64_t q_pe_stride_s,
int64_t out_stride_qs,
int64_t out_stride_ks,
int64_t HK,
int64_t k_pe_stride_s,
int64_t q_pe_stride_n,
int64_t out_stride_qn) {
int64_t COFF = HR / 2;
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t seq{0}, head_id{0};
data_index_init(begin, seq, seq_len, head_id, num_head);
for (int64_t i = begin; i < end; ++i) {
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
int64_t out_offset_k = seq * out_stride_ks;
int64_t p = 0;
scalar_t* sin_start = nullptr;
scalar_t* cos_start = nullptr;
// step 0) get the rotary position embedding for the current position
p = t_pos[seq];
sin_start = t_emb_pos + p * HR + COFF;
cos_start = t_emb_pos + p * HR;
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
// head of query/key
for (int64_t h = 0; h < rotary_dim; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
scalar_t in1 = q_pe[in_offset_q + h];
scalar_t in2 = q_pe[in_offset_q + h + 1];
scalar_t out1 = in1 * cos - in2 * sin;
scalar_t out2 = in2 * cos + in1 * sin;
q_pe_out[out_offset_q + h] = out1;
q_pe_out[out_offset_q + h + 1] = out2;
}
for (int64_t h = 0; h < HK; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
int64_t k_pe_offset = seq * k_pe_stride_s;
scalar_t in1_k = k_pe[k_pe_offset + h];
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
scalar_t out1_k = in1_k * cos - in2_k * sin;
scalar_t out2_k = in2_k * cos + in1_k * sin;
k_pe_out[out_offset_k + h] = out1_k;
k_pe_out[out_offset_k + h + 1] = out2_k;
}
// move to the next index
data_index_step(seq, seq_len, head_id, num_head);
}
});
}
} // namespace
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
RECORD_FUNCTION(
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
CHECK_INPUT(t_pos);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
CHECK_INPUT(t_emb_pos);
CHECK_DIM(1, t_pos);
CHECK_DIM(3, q_pe);
CHECK_DIM(3, k_pe);
CHECK_DIM(2, t_emb_pos);
int64_t seq_len = q_pe.size(0);
int64_t num_head = q_pe.size(1);
int64_t rotary_dim = q_pe.size(2);
int64_t HK = k_pe.size(2);
int64_t HR = t_emb_pos.size(1);
CHECK_EQ(HR, rotary_dim);
CHECK_EQ(k_pe.size(0), seq_len);
CHECK_EQ(k_pe.size(1), 1);
CHECK_EQ(t_pos.size(0), seq_len);
CHECK_EQ(HK, rotary_dim);
at::Tensor q_pe_out = at::empty_like(q_pe);
at::Tensor k_pe_out = at::empty_like(k_pe);
int64_t q_pe_stride_s = q_pe.stride(0);
int64_t q_pe_stride_n = q_pe.stride(1);
int64_t k_pe_stride_s = k_pe.stride(0);
int64_t out_stride_qs = q_pe_out.stride(0);
int64_t out_stride_qn = q_pe_out.stride(1);
int64_t out_stride_ks = k_pe_out.stride(0);
const auto input_dtype = q_pe.scalar_type();
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
rope_kernel_impl<scalar_t>(
q_pe_out.data_ptr<scalar_t>(),
k_pe_out.data_ptr<scalar_t>(),
t_pos.data_ptr<int64_t>(),
q_pe.data_ptr<scalar_t>(),
k_pe.data_ptr<scalar_t>(),
t_emb_pos.data_ptr<scalar_t>(),
seq_len,
num_head,
rotary_dim,
HR,
q_pe_stride_s,
out_stride_qs,
out_stride_ks,
HK,
k_pe_stride_s,
q_pe_stride_n,
out_stride_qn);
});
return std::make_tuple(q_pe_out, k_pe_out);
}
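// Hedged scalar sketch of the rotation applied by rope_kernel_impl above
// (illustrative only; rotate_pair_reference is a hypothetical helper, not part
// of sgl-kernel): each interleaved pair (x0, x1) is rotated by the cached angle
// for position p, with cos read from t_emb_pos[p * HR + h/2] and sin from
// t_emb_pos[p * HR + h/2 + HR/2].
[[maybe_unused]] static inline void
rotate_pair_reference(float x0, float x1, float cos_v, float sin_v, float& out0, float& out1) {
  out0 = x0 * cos_v - x1 * sin_v;  // matches out1 in the kernel loop
  out1 = x1 * cos_v + x0 * sin_v;  // matches out2 in the kernel loop
}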
#include "shm.h"
#include <ATen/ATen.h>
#include <errno.h>
#include <fcntl.h>
#include <immintrin.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/mman.h>
#include <unistd.h>
// states for collectives
enum coll_state {
coll_begin = 0,
coll_allreduce_naive__copy_in_done,
coll_allreduce_naive__reduce_done,
  // alternative states used when allreduce is working on the other buffer
  // of the double buffer.
coll_alt1_allreduce_naive__copy_in_done,
coll_alt2_allreduce_naive__copy_in_done,
coll_alt1_allreduce_naive__reduce_done,
coll_allgather_naive__copy_in_done,
coll_alt1_allgather_naive__copy_in_done,
coll_alt2_allgather_naive__copy_in_done,
};
// SHM building blocks
struct SharedData {
const char* name;
int descriptor;
void* bytes;
size_t nbytes;
};
void shared_open(SharedData* data, const char* name, size_t nbytes) {
int d = shm_open(name, O_RDWR, S_IRUSR | S_IWUSR);
if (d != -1) {
void* bytes = mmap(NULL, nbytes, PROT_READ | PROT_WRITE, MAP_SHARED, d, 0);
data->name = name;
data->descriptor = d;
data->bytes = bytes;
data->nbytes = nbytes;
} else {
if (errno != ENOENT) {
      // don't print if the shm cannot be found: the caller is expected to retry
      // until the other ranks have created it
printf("shared_open %s failed, errno=%d\n", name, errno);
}
data->descriptor = -1;
}
}
void shared_create(SharedData* data, const char* name, void* bytes, size_t nbytes) {
int d = shm_open(name, O_CREAT | O_RDWR, S_IRUSR | S_IWUSR);
if (d != -1) {
    if ((nbytes = write(d, bytes, nbytes)) != 0) {
shared_open(data, name, nbytes);
}
} else {
printf("shared_create %s failed\n", name);
}
}
static int world_size;
// SHM based allreduce helper functions
// buffer that holds shm name
#define NAME_BUF_SIZE 1000
#define MAX_BUF_SIZE (1048576 * 32)
#define NAIVE_ALLREDUCE_THRESHOLD 1048576
#define SHM_BUFFER_NAME "deepspeed_allreduce_buffer"
struct allreduce_workspace {
  enum coll_state states[2];  // idx=0 -- state for symmetric_naive_all_reduce
                              // idx=1 -- state for distributed_naive_all_reduce
  // double buffer to avoid syncing between rounds
  // offset 0 .. 2 * NAIVE_ALLREDUCE_THRESHOLD : buffers for symmetric_naive_all_reduce
  // after 2 * NAIVE_ALLREDUCE_THRESHOLD       : buffers for distributed_naive_all_reduce
  char buffer[2 * NAIVE_ALLREDUCE_THRESHOLD + 2 * MAX_BUF_SIZE];
};
#define BUFFER0_OFFSET(current_buffer) ((current_buffer) * NAIVE_ALLREDUCE_THRESHOLD)
#define BUFFER1_OFFSET(current_buffer) (2 * NAIVE_ALLREDUCE_THRESHOLD + (current_buffer) * MAX_BUF_SIZE)
struct allreduce_workspace** workspace;
// buffer for small messages, double buffer
char** symmetric_buffer[2];
// buffer for large messages, double buffer
char** distributed_buffer[2];
void wait_buffer_state_until_2(int index, enum coll_state state0, enum coll_state state1, int state_group) {
volatile enum coll_state* state_ptr = &(workspace[index]->states[state_group]);
while (1) {
volatile enum coll_state cur_state = *state_ptr;
if (cur_state == state0 || cur_state == state1) break;
}
}
__m512 cvt_bf16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_bf16_to_fp32(const __m256i src) {
auto y = _mm512_cvtepu16_epi32(src);
return _mm512_castsi512_ps(_mm512_bslli_epi128(y, 2));
}
inline __m256i cvt_fp32_to_bf16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_bf16(const __m512 src) {
__m512i value = _mm512_castps_si512(src);
__m512i nan = _mm512_set1_epi32(0xffff);
auto mask_value = _mm512_cmp_ps_mask(src, src, _CMP_ORD_Q);
__m512i ones = _mm512_set1_epi32(0x1);
__m512i vec_bias = _mm512_set1_epi32(0x7fff);
// uint32_t lsb = (input >> 16) & 1;
auto t_value = _mm512_and_si512(_mm512_srli_epi32(value, 16), ones);
// uint32_t rounding_bias = 0x7fff + lsb;
t_value = _mm512_add_epi32(t_value, vec_bias);
// input += rounding_bias;
t_value = _mm512_add_epi32(t_value, value);
// input = input >> 16;
t_value = _mm512_srli_epi32(t_value, 16);
// Check NaN before converting back to bf16
t_value = _mm512_mask_blend_epi32(mask_value, nan, t_value);
return _mm512_cvtusepi32_epi16(t_value);
}
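// Hedged scalar sketch of the conversion above (illustrative only;
// fp32_to_bf16_scalar is a hypothetical helper, not used by the collectives):
// round-to-nearest-even on the upper 16 bits, with NaN mapped to 0xffff.
// Assumes memcpy/<cstring> is reachable through the existing includes.
[[maybe_unused]] static inline uint16_t fp32_to_bf16_scalar(float f) {
  uint32_t bits;
  memcpy(&bits, &f, sizeof(bits));
  if (f != f) return 0xffff;              // NaN, as in the masked blend above
  uint32_t lsb = (bits >> 16) & 1;        // parity of the bit that survives truncation
  uint32_t rounding_bias = 0x7fff + lsb;  // round half to even
  return (uint16_t)((bits + rounding_bias) >> 16);
}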
__m512 cvt_fp16_to_fp32(const __m256i src) __attribute__((target("avx512bw")));
inline __m512 cvt_fp16_to_fp32(const __m256i src) {
return _mm512_cvtph_ps(src);
}
inline __m256i cvt_fp32_to_fp16(const __m512 src) __attribute__((target("avx512bw")));
inline __m256i cvt_fp32_to_fp16(const __m512 src) {
return _mm512_cvtps_ph(src, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
}
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers)
__attribute__((target("avx512bw")));
void reduce_all_buffers(
int start_elements,
int num_elements,
c10::ScalarType scalar_type,
int to_buffer_idx,
char* to_buffer,
char** buffers) {
switch (scalar_type) {
case c10::ScalarType::BFloat16:
reduce_bf16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Half:
reduce_fp16_buffers(start_elements, num_elements, to_buffer, buffers);
break;
case c10::ScalarType::Float:
reduce_fp32_buffers(start_elements, num_elements, to_buffer, buffers);
break;
default:
assert(!"Should not get here");
}
}
#define CVT_ADD_BF16(x) \
do { \
auto in##x##_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
// The reduce functions below use a vectorized algorithm; the number of bytes
// processed per iteration depends on the vector length: a 256-bit vector ==> 32
// bytes, a 512-bit vector ==> 64 bytes. If you change the implementation of
// reduce_bf16_buffers, etc., check whether this number needs to be changed.
// Note: the switch statements below fall through intentionally, accumulating the
// buffers of ranks world_size-1 down to 1 on top of rank 0's buffer.
#define VECTOR_LENGTH_IN_BYTES 32
void reduce_bf16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (world_size) {
case 16:
CVT_ADD_BF16(15);
case 15:
CVT_ADD_BF16(14);
case 14:
CVT_ADD_BF16(13);
case 13:
CVT_ADD_BF16(12);
case 12:
CVT_ADD_BF16(11);
case 11:
CVT_ADD_BF16(10);
case 10:
CVT_ADD_BF16(9);
case 9:
CVT_ADD_BF16(8);
case 8:
CVT_ADD_BF16(7);
case 7:
CVT_ADD_BF16(6);
case 6:
CVT_ADD_BF16(5);
case 5:
CVT_ADD_BF16(4);
case 4:
CVT_ADD_BF16(3);
case 3:
CVT_ADD_BF16(2);
case 2:
CVT_ADD_BF16(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_bf16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
}
}
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_bf16(inout_val));
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(at::BFloat16*)(buffers[j] + i);
}
*(at::BFloat16*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
#define CVT_ADD_FP16(x) \
do { \
auto in##x##_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[x] + i))); \
inout_val = _mm512_add_ps(inout_val, in##x##_val); \
} while (0)
void reduce_fp16_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 2;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[0] + i)));
switch (world_size) {
case 16:
CVT_ADD_FP16(15);
case 15:
CVT_ADD_FP16(14);
case 14:
CVT_ADD_FP16(13);
case 13:
CVT_ADD_FP16(12);
case 12:
CVT_ADD_FP16(11);
case 11:
CVT_ADD_FP16(10);
case 10:
CVT_ADD_FP16(9);
case 9:
CVT_ADD_FP16(8);
case 8:
CVT_ADD_FP16(7);
case 7:
CVT_ADD_FP16(6);
case 6:
CVT_ADD_FP16(5);
case 5:
CVT_ADD_FP16(4);
case 4:
CVT_ADD_FP16(3);
case 3:
CVT_ADD_FP16(2);
case 2:
CVT_ADD_FP16(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = cvt_fp16_to_fp32(_mm256_loadu_si256((__m256i*)(buffers[j] + i)));
inout_val = _mm512_add_ps(inout_val, in_val);
}
}
_mm256_storeu_si256((__m256i*)(to_buffer + i), cvt_fp32_to_fp16(inout_val));
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(at::Half*)(buffers[j] + i);
}
*(at::Half*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
#define CVT_ADD_F32(x) \
do { \
auto in##x##_val = _mm256_loadu_ps((float*)(buffers[x] + i)); \
inout_val = _mm256_add_ps(inout_val, in##x##_val); \
} while (0)
void reduce_fp32_buffers(int start_elements, int num_elements, char* to_buffer, char** buffers) {
const int element_size = 4;
const int vector_length = VECTOR_LENGTH_IN_BYTES / element_size;
int main_elements = num_elements - (num_elements % vector_length);
int remain_elements = num_elements % vector_length;
// process aligned part
#pragma omp parallel for
for (int i = start_elements * element_size; i < (start_elements + main_elements) * element_size;
i += VECTOR_LENGTH_IN_BYTES) {
auto inout_val = _mm256_loadu_ps((float*)(buffers[0] + i));
switch (world_size) {
case 16:
CVT_ADD_F32(15);
case 15:
CVT_ADD_F32(14);
case 14:
CVT_ADD_F32(13);
case 13:
CVT_ADD_F32(12);
case 12:
CVT_ADD_F32(11);
case 11:
CVT_ADD_F32(10);
case 10:
CVT_ADD_F32(9);
case 9:
CVT_ADD_F32(8);
case 8:
CVT_ADD_F32(7);
case 7:
CVT_ADD_F32(6);
case 6:
CVT_ADD_F32(5);
case 5:
CVT_ADD_F32(4);
case 4:
CVT_ADD_F32(3);
case 3:
CVT_ADD_F32(2);
case 2:
CVT_ADD_F32(1);
case 1:
break;
default:
for (int j = 1; j < world_size; j++) {
auto in_val = _mm256_loadu_ps((float*)(buffers[j] + i));
inout_val = _mm256_add_ps(inout_val, in_val);
}
}
_mm256_storeu_ps((float*)(to_buffer + i), inout_val);
}
// process remaining part
int i = (start_elements + main_elements) * element_size;
while (remain_elements > 0) {
float val = 0.0f;
for (int j = 0; j < world_size; j++) {
val += *(float*)(buffers[j] + i);
}
*(float*)(to_buffer + i) = val;
remain_elements--;
i += element_size;
}
}
static bool is_initialized = false;
static int world_rank;
void shm_initialize(int size, int rank, char* addr_string, char* port_string) {
if (is_initialized) {
return;
}
is_initialized = true;
world_size = size;
world_rank = rank;
char shm_name_prefix[NAME_BUF_SIZE];
char shm_name[NAME_BUF_SIZE];
snprintf(shm_name_prefix, NAME_BUF_SIZE, "%s_%d_%s_%s", SHM_BUFFER_NAME, getuid(), addr_string, port_string);
// create shared workspace for SHM based allreduce
SharedData allreduce_buffer;
// allocate workspace_buf for current rank
struct allreduce_workspace* workspace_buf;
struct allreduce_workspace* workspace_buf_other;
workspace_buf = (struct allreduce_workspace*)malloc(sizeof(struct allreduce_workspace));
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, rank);
shared_create(&allreduce_buffer, shm_name, workspace_buf, sizeof(struct allreduce_workspace));
workspace_buf = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace_buf->states[0] = coll_alt2_allreduce_naive__copy_in_done;
workspace_buf->states[1] = coll_begin;
// create the workspace pointer list
workspace = (struct allreduce_workspace**)malloc(size * sizeof(struct allreduce_workspace*));
symmetric_buffer[0] = (char**)malloc(size * sizeof(char**));
symmetric_buffer[1] = (char**)malloc(size * sizeof(char**));
distributed_buffer[0] = (char**)malloc(size * sizeof(char**));
distributed_buffer[1] = (char**)malloc(size * sizeof(char**));
// map shm of all ranks
for (int i = 0; i < size; i++) {
if (i != rank) {
snprintf(shm_name, NAME_BUF_SIZE, "%s_%d", shm_name_prefix, i);
// printf("open %s, %d\n", shm_name, rank);
do {
shared_open(&allreduce_buffer, shm_name, sizeof(struct allreduce_workspace));
} while (allreduce_buffer.descriptor == -1 && errno == ENOENT);
workspace_buf_other = (struct allreduce_workspace*)allreduce_buffer.bytes;
workspace[i] = workspace_buf_other;
} else {
workspace[i] = workspace_buf;
}
symmetric_buffer[0][i] = workspace[i]->buffer + BUFFER0_OFFSET(0);
symmetric_buffer[1][i] = workspace[i]->buffer + BUFFER0_OFFSET(1);
distributed_buffer[0][i] = workspace[i]->buffer + BUFFER1_OFFSET(0);
distributed_buffer[1][i] = workspace[i]->buffer + BUFFER1_OFFSET(1);
}
}
static void parallel_memcpy(void* to, void* from, size_t n_bytes) __attribute__((target("avx512bw")));
static void parallel_memcpy(void* to, void* from, size_t n_bytes) {
auto aligned_bytes = n_bytes - (n_bytes % VECTOR_LENGTH_IN_BYTES);
// process aligned part
#pragma omp parallel for
for (int i = 0; i < aligned_bytes; i += VECTOR_LENGTH_IN_BYTES) {
auto val = _mm256_loadu_si256((__m256i*)((char*)from + i));
_mm256_storeu_si256((__m256i*)((char*)to + i), val);
}
// process remaining part
for (int i = aligned_bytes; i < n_bytes; i++) {
*((char*)to + i) = *((char*)from + i);
}
}
#define positive_mod(num, mod) ((((num) % (mod)) + (mod)) % (mod))
#define rank_mod(rank) positive_mod(rank, world_size)
size_t slice_size(size_t chunk_el, int slice_idx) {
size_t slice_size = chunk_el / world_size;
return slice_idx == world_size - 1 ? slice_size + (chunk_el % world_size) : slice_size;
}
char* slice_data(char* data_ptr, size_t chunk_el, int el_size, int slice_idx) {
size_t slice_size = chunk_el / world_size;
size_t el_offset = slice_size * slice_idx;
return data_ptr + el_offset * el_size;
}
size_t slice_el_start(size_t chunk_el, int slice_idx) {
size_t slice_size = chunk_el / world_size;
return slice_size * slice_idx;
}
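// Worked example (illustrative): with chunk_el = 10 and world_size = 4, the base
// slice size is 10 / 4 = 2 elements, so ranks 0..2 own elements [0,2), [2,4) and
// [4,6), while the last rank absorbs the remainder and owns [6,10).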
void symmetric_naive_all_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 0;
static int current_buffer = 0;
static int state_idx = 0;
enum coll_state copy_current, copy_next;
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
copy_next = coll_alt2_allreduce_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allreduce_naive__copy_in_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
parallel_memcpy(symmetric_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
    // wait until the other rank has copied in its buffer
if (i != world_rank) {
wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
}
  // each rank reduces the buffer independently, so there is no need for
  // synchronization afterward
reduce_all_buffers(0, chunk_el, scalar_type, world_rank, data_ptr, symmetric_buffer[current_buffer]);
// switch buffer
current_buffer = 1 - current_buffer;
}
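// Note on the three-way state rotation above (descriptive): the waiter accepts
// either the current or the next copy-in state, so a peer that has already raced
// one round ahead does not block it. With only two states, the value a peer
// published in the previous round would look the same as "already copied in this
// round"; rotating through three states keeps the previous, current and next
// rounds distinguishable.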
// distributed naive allreduce: each rank does a naive reduce on its own slice
void distributed_naive_reduce(char* data_ptr, c10::ScalarType scalar_type, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
enum coll_state copy_current, copy_next, reduce_current;
  // similar to symmetric_naive_all_reduce, but here we only need two sets of
  // states, because distributed naive reduce has two barriers in the algorithm
switch (state_idx) {
case 0:
copy_current = coll_allreduce_naive__copy_in_done;
reduce_current = coll_allreduce_naive__reduce_done;
copy_next = coll_alt1_allreduce_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allreduce_naive__copy_in_done;
reduce_current = coll_alt1_allreduce_naive__reduce_done;
copy_next = coll_allreduce_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 2;
int data_size = chunk_size / chunk_el;
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks have copied in their buffers
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, reduce_current, state_group);
}
// reduce scatter
reduce_all_buffers(
slice_el_start(chunk_el, world_rank),
slice_size(chunk_el, world_rank),
scalar_type,
world_rank,
distributed_buffer[current_buffer][world_rank],
distributed_buffer[current_buffer]);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = reduce_current;
for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks have reduced their slices
if (i != world_rank) wait_buffer_state_until_2(i, reduce_current, copy_next, state_group);
}
for (int i = 0; i < world_size; i++) {
int rank = (i + world_rank) % world_size;
parallel_memcpy(
slice_data(data_ptr, chunk_el, data_size, rank),
slice_data(distributed_buffer[current_buffer][rank], chunk_el, chunk_size / chunk_el, rank),
slice_size(chunk_el, rank) * data_size);
}
current_buffer = 1 - current_buffer;
}
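// Summary of distributed_naive_reduce (descriptive): (1) every rank copies its
// chunk into its own shm buffer and publishes copy_current; (2) once all peers
// have copied, each rank reduces only its assigned slice, writing the result
// into its own shm buffer (a reduce-scatter), and publishes reduce_current;
// (3) once all peers have reduced, each rank gathers every peer's reduced slice
// from that peer's shm buffer back into its local tensor.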
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size) {
for (int offset = 0; offset < data_size; offset += MAX_BUF_SIZE) {
auto data_ptr = ((char*)(data.data_ptr()) + offset);
size_t chunk_size = data_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : data_size - offset;
size_t chunk_el = chunk_size / (data_size / numel);
if (chunk_size < NAIVE_ALLREDUCE_THRESHOLD) {
symmetric_naive_all_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
} else {
distributed_naive_reduce(data_ptr, data.scalar_type(), chunk_size, chunk_el);
}
}
}
void naive_all_gather(char* result_ptr, char* data_ptr, size_t res_stride, size_t chunk_size, size_t chunk_el) {
const int state_group = 1;
static int current_buffer = 0;
static int state_idx = 0;
enum coll_state copy_current, copy_next;
switch (state_idx) {
case 0:
copy_current = coll_allgather_naive__copy_in_done;
copy_next = coll_alt1_allgather_naive__copy_in_done;
break;
case 1:
copy_current = coll_alt1_allgather_naive__copy_in_done;
copy_next = coll_alt2_allgather_naive__copy_in_done;
break;
case 2:
copy_current = coll_alt2_allgather_naive__copy_in_done;
copy_next = coll_allgather_naive__copy_in_done;
break;
default:
assert(!"Should not get here.");
}
state_idx = (state_idx + 1) % 3;
int data_size = chunk_size / chunk_el;
parallel_memcpy(distributed_buffer[current_buffer][world_rank], data_ptr, chunk_size);
std::atomic_thread_fence(std::memory_order_release);
workspace[world_rank]->states[state_group] = copy_current;
for (int i = 0; i < world_size; i++) {
    // wait until all the other ranks have copied in their buffers
if (i != world_rank) wait_buffer_state_until_2(i, copy_current, copy_next, state_group);
}
for (int i = 0; i < world_size; i++) {
parallel_memcpy(result_ptr + i * res_stride, distributed_buffer[current_buffer][i], chunk_size);
}
current_buffer = 1 - current_buffer;
}
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size) {
size_t dim_el = data.stride(dim) * data.size(dim);
int dtype_size = data_size / numel;
size_t dim_size = dim_el * dtype_size;
int dim_count = data_size / dim_size;
auto data_ptr = (char*)(data.data_ptr());
auto result_ptr = (char*)(result.data_ptr());
for (int i = 0; i < dim_count; i++) {
for (int offset = 0; offset < dim_size; offset += MAX_BUF_SIZE) {
size_t chunk_size = dim_size - offset > MAX_BUF_SIZE ? MAX_BUF_SIZE : dim_size - offset;
size_t chunk_el = chunk_size / dtype_size;
naive_all_gather(
result_ptr + i * dim_size * world_size + offset,
data_ptr + i * dim_size + offset,
dim_size,
chunk_size,
chunk_el);
}
}
return result;
}
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#ifndef __SHM_COLLECTIVES__
#define __SHM_COLLECTIVES__
#define VECTOR_LENGTH_IN_BYTES 32
void shm_initialize(int size, int rank, char* addr_string, char* port_string);
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
#endif
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t, int SIZE>
inline void softmax(float* __restrict__ out, const scalar_t* __restrict__ input) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
// step 1: get max
fVec max_fvec = fVec(-std::numeric_limits<float>::infinity());
if constexpr (SIZE < kVecSize) {
// SIZE = 1, 2, 4, 8, 16; only the top half is used
bVec x_bvec = bVec::loadu(input, SIZE);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
x_fvec0 = fVec::set(max_fvec, x_fvec0, SIZE);
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
x_fvec0.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
max_fvec = at::vec::maximum(max_fvec, x_fvec0);
max_fvec = at::vec::maximum(max_fvec, x_fvec1);
x_fvec0.store(out + d);
x_fvec1.store(out + d + fVec::size());
}
}
float max_val = vec_reduce_max(max_fvec);
max_fvec = fVec(max_val);
// step 2: sum of (x - max).exp()
fVec sum_fvec = fVec(float(0));
if constexpr (SIZE < fVec::size()) {
// SIZE = 1, 2, 4, 8
fVec x_fvec = (fVec::loadu(out, SIZE) - max_fvec).exp_u20();
x_fvec = fVec::set(sum_fvec, x_fvec, SIZE);
sum_fvec += x_fvec;
x_fvec.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += fVec::size()) {
fVec x_fvec = (fVec::loadu(out + d) - max_fvec).exp_u20();
sum_fvec += x_fvec;
x_fvec.store(out + d);
}
}
float sum_val = vec_reduce_sum(sum_fvec);
// step 3: x * (1 / sum)
sum_fvec = fVec(1.f / sum_val);
if constexpr (SIZE < fVec::size()) {
// SIZE = 1, 2, 4, 8
fVec out_fvec = fVec::loadu(out, SIZE) * sum_fvec;
out_fvec.store(out, SIZE);
} else {
for (int d = 0; d < SIZE; d += fVec::size()) {
fVec out_fvec = fVec::loadu(out + d) * sum_fvec;
out_fvec.store(out + d);
}
}
}
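// Hedged scalar reference for the vectorized softmax above (illustrative only;
// softmax_scalar_reference is a hypothetical helper, not used by the kernels):
// the same three passes -- running max, exp(x - max) accumulation, then scaling
// by 1 / sum -- written without vectorization. Assumes <cmath> is reachable
// through common.h.
template <typename scalar_t>
void softmax_scalar_reference(float* __restrict__ out, const scalar_t* __restrict__ input, int size) {
  float max_val = -std::numeric_limits<float>::infinity();
  for (int d = 0; d < size; ++d) {
    max_val = std::max(max_val, static_cast<float>(input[d]));
  }
  float sum = 0.f;
  for (int d = 0; d < size; ++d) {
    out[d] = std::exp(static_cast<float>(input[d]) - max_val);
    sum += out[d];
  }
  const float inv_sum = 1.f / sum;
  for (int d = 0; d < size; ++d) {
    out[d] *= inv_sum;
  }
}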
template <typename scalar_t, int NUM_EXPERTS>
void grouped_topk_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
int64_t num_tokens,
int64_t topk,
int64_t num_groups,
int64_t topk_group,
bool renormalize) {
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
alignas(64) float scores[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_groups);
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
// do softmax to get scores
softmax<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
// find max score per group
for (int64_t g = 0; g < num_groups; ++g) {
float gmax = -std::numeric_limits<float>::infinity();
for (int64_t e = 0; e < num_experts_per_group; ++e) {
gmax = std::max(gmax, scores[g * num_experts_per_group + e]);
}
queue[g] = {gmax, g};
}
// find group topk
std::partial_sort(
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
for (int64_t g = 0; g < topk_group; ++g) {
int32_t group_idx = queue[g].second;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
int32_t expert_idx = group_idx * num_experts_per_group + e;
queue2[g * num_experts_per_group + e] = {scores[expert_idx], expert_idx};
}
}
// find global topk
std::partial_sort(
queue2.begin(), queue2.begin() + topk, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] = queue2[j].first;
topk_ids[i * topk + j] = queue2[j].second;
}
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < topk; ++j) {
sum += topk_weights[i * topk + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < topk; ++j) {
topk_weights[i * topk + j] *= scale;
}
}
}
});
}
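// Worked example (illustrative): with NUM_EXPERTS = 8, num_groups = 4,
// topk_group = 2 and topk = 2, the 8 softmax scores are split into 4 groups of
// 2 experts; the 2 groups with the largest per-group maximum are kept, their 4
// expert scores are re-ranked, and the global top-2 of those become
// topk_ids / topk_weights (optionally renormalized to sum to 1).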
template <typename scalar_t, int SIZE>
inline void sigmoid(float* __restrict__ out, const scalar_t* __restrict__ input) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
const fVec one = fVec(1.f);
constexpr int kVecSize = bVec::size();
for (int d = 0; d < SIZE; d += kVecSize) {
bVec x_bvec = bVec::loadu(input + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
x_fvec0 = one / (one + x_fvec0.neg().exp_u20());
x_fvec1 = one / (one + x_fvec1.neg().exp_u20());
x_fvec0.store(out + d);
x_fvec1.store(out + d + fVec::size());
}
}
template <typename scalar_t, int SIZE>
inline void
apply_bias(float* __restrict__ scores2, const float* __restrict__ scores, const scalar_t* __restrict__ bias) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
for (int d = 0; d < SIZE; d += bVec::size()) {
bVec bias_vec = bVec::loadu(bias + d);
fVec bias0, bias1;
std::tie(bias0, bias1) = at::vec::convert_to_float(bias_vec);
fVec x0 = fVec::loadu(scores + d) + bias0;
fVec x1 = fVec::loadu(scores + d + fVec::size()) + bias1;
x0.store(scores2 + d);
x1.store(scores2 + d + fVec::size());
}
}
template <typename scalar_t, int NUM_EXPERTS, int TOPK>
void biased_grouped_topk_kernel_impl(
float* __restrict__ topk_weights,
int32_t* __restrict__ topk_ids,
const scalar_t* __restrict__ gating_output,
const scalar_t* __restrict__ bias,
int64_t num_tokens,
int64_t num_groups,
int64_t topk_group,
bool renormalize) {
using Vec = at::vec::Vectorized<float>;
const int64_t num_experts_per_group = NUM_EXPERTS / num_groups;
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
// scores: sigmoid
alignas(64) float scores[NUM_EXPERTS];
// scores for choice: sigmoid + bias
alignas(64) float scores2[NUM_EXPERTS];
using elem_t = std::pair<float, int32_t>;
std::vector<elem_t> queue(num_groups);
std::vector<elem_t> queue2(topk_group * num_experts_per_group);
for (int64_t i = begin; i < end; ++i) {
// do sigmoid to get scores
sigmoid<scalar_t, NUM_EXPERTS>(scores, gating_output + i * NUM_EXPERTS);
apply_bias<scalar_t, NUM_EXPERTS>(scores2, scores, bias);
for (int64_t g = 0; g < num_groups; ++g) {
// find the max
float gmax = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
scores2 + g * num_experts_per_group,
num_experts_per_group);
// find position of first max,
// note that we may have multiple max values.
int first_max_idx = -1;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
if (scores2[g * num_experts_per_group + e] == gmax) {
first_max_idx = g * num_experts_per_group + e;
break;
}
}
// find the 2nd max
scores2[first_max_idx] = -std::numeric_limits<float>::infinity();
float gmax2 = at::vec::reduce_all<float>(
[](Vec& x, Vec& y) { return at::vec::maximum(x, y); },
scores2 + g * num_experts_per_group,
num_experts_per_group);
// restore scores for choice
scores2[first_max_idx] = gmax;
queue[g] = {gmax + gmax2, g};
}
// find group topk
std::partial_sort(
queue.begin(), queue.begin() + topk_group, queue.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
for (int64_t g = 0; g < topk_group; ++g) {
int32_t group_idx = queue[g].second;
for (int64_t e = 0; e < num_experts_per_group; ++e) {
int32_t expert_idx = group_idx * num_experts_per_group + e;
queue2[g * num_experts_per_group + e] = {scores2[expert_idx], expert_idx};
}
}
// find global topk
std::partial_sort(
queue2.begin(), queue2.begin() + TOPK, queue2.end(), [](const elem_t& x, const elem_t& y) -> bool {
return x.first > y.first;
});
for (int j = 0; j < TOPK; ++j) {
int32_t index = queue2[j].second;
topk_ids[i * TOPK + j] = index;
topk_weights[i * TOPK + j] = scores[index];
}
#if defined(CPU_CAPABILITY_AVX512)
if (renormalize) {
__mmask16 mask = (1ULL << TOPK) - 1;
__m512 x = _mm512_maskz_loadu_ps(mask, topk_weights + i * TOPK);
float sum = _mm512_reduce_add_ps(x);
__m512 vscale = _mm512_set1_ps(1.f / sum);
__m512 y = _mm512_mul_ps(x, vscale);
_mm512_mask_storeu_ps(topk_weights + i * TOPK, mask, y);
}
#else
if (renormalize) {
float sum = 0.f;
for (int64_t j = 0; j < TOPK; ++j) {
sum += topk_weights[i * TOPK + j];
}
float scale = 1.f / sum;
for (int64_t j = 0; j < TOPK; ++j) {
topk_weights[i * TOPK + j] *= scale;
}
}
#endif
}
});
}
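// Note (descriptive): unlike grouped_topk_kernel_impl, a group is ranked here by
// the sum of its two largest bias-corrected scores (gmax + gmax2), and while
// expert selection uses the biased scores2, the returned topk_weights are taken
// from the unbiased sigmoid scores.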
#define LAUNCH_GROUPED_TOPK_KERNEL(NE) \
grouped_topk_kernel_impl<scalar_t, NE>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
num_tokens, \
topk, \
num_expert_group, \
topk_group, \
renormalize);
#define LAUNCH_BIASED_GROUPED_TOPK_KERNEL(NE, NTOPK) \
biased_grouped_topk_kernel_impl<scalar_t, NE, NTOPK>( \
topk_weights.data_ptr<float>(), \
topk_ids.data_ptr<int32_t>(), \
gating_output.data_ptr<scalar_t>(), \
correction_bias.data_ptr<scalar_t>(), \
num_tokens, \
num_expert_group, \
topk_group, \
renormalize);
} // anonymous namespace
// grouped topk for DeepSeek V2
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group) {
RECORD_FUNCTION("sgl-kernel::grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output}));
CHECK_INPUT(gating_output);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "grouped_topk_kernel", [&] {
switch (num_experts) {
case 1:
LAUNCH_GROUPED_TOPK_KERNEL(1);
break;
case 2:
LAUNCH_GROUPED_TOPK_KERNEL(2);
break;
case 4:
LAUNCH_GROUPED_TOPK_KERNEL(4);
break;
case 8:
LAUNCH_GROUPED_TOPK_KERNEL(8);
break;
case 16:
LAUNCH_GROUPED_TOPK_KERNEL(16);
break;
case 32:
LAUNCH_GROUPED_TOPK_KERNEL(32);
break;
case 64:
LAUNCH_GROUPED_TOPK_KERNEL(64);
break;
case 128:
LAUNCH_GROUPED_TOPK_KERNEL(128);
break;
case 160:
LAUNCH_GROUPED_TOPK_KERNEL(160);
break;
case 256:
LAUNCH_GROUPED_TOPK_KERNEL(256);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
// biased grouped topk for DeepSeek V3/R1
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
at::Tensor& correction_bias,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group) {
RECORD_FUNCTION(
"sgl-kernel::biased_grouped_topk_cpu", std::vector<c10::IValue>({hidden_states, gating_output, correction_bias}));
CHECK_INPUT(gating_output);
CHECK_INPUT(correction_bias);
const auto st = hidden_states.scalar_type();
CHECK_EQ(gating_output.scalar_type(), st);
CHECK_EQ(correction_bias.scalar_type(), st);
int64_t num_tokens = hidden_states.size(0);
int64_t num_experts = gating_output.size(1);
TORCH_CHECK(gating_output.size(0) == num_tokens, "Number of tokens mismatch");
TORCH_CHECK(correction_bias.numel() == num_experts, "Bias shape mismatch");
at::Tensor topk_weights = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kFloat));
at::Tensor topk_ids = at::empty({num_tokens, topk}, hidden_states.options().dtype(at::kInt));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "biased_grouped_topk_kernel", [&] {
    // currently only the DeepSeek-V3 configuration (topk == 8, 256 experts) is supported
TORCH_CHECK(topk == 8, "Unexpected topk: ", topk);
switch (num_experts) {
case 256:
LAUNCH_BIASED_GROUPED_TOPK_KERNEL(256, 8);
break;
default:
TORCH_CHECK(false, "Unexpected num_experts: ", num_experts);
}
});
return std::make_tuple(topk_weights, topk_ids);
}
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <ATen/ATen.h>
#include <torch/extension.h>
#include <torch/library.h>
#include "shm.h"
// silu_and_mul
at::Tensor silu_and_mul_cpu(at::Tensor& input);
// rmsnorm
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps);
// fused_add_rmsnorm
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps);
// topk
std::tuple<at::Tensor, at::Tensor> grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group);
std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
at::Tensor& hidden_states,
at::Tensor& gating_output,
at::Tensor& correction_bias,
int64_t topk,
bool renormalize,
int64_t num_expert_group,
int64_t topk_group);
// attention
void decode_attention_cpu(
at::Tensor& query,
at::Tensor& output,
at::Tensor& k_cache,
    at::Tensor& v_cache,
at::Tensor& attn_logits,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
double sm_scale,
double logit_cap);
void extend_attention_cpu(
at::Tensor& q_extend,
at::Tensor& k_extend,
at::Tensor& v_extend,
at::Tensor& o_extend,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
at::Tensor& seq_lens,
at::Tensor& extend_seq_lens,
at::Tensor& extend_start_loc,
int64_t max_len_extend,
double sm_scale,
double logit_cap);
// weight prepack
at::Tensor convert_weight_packed(at::Tensor& weight);
// quant
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A);
// gemm
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni);
// igemm
at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales1,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// quant + igemm
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni);
// bmm
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale);
// fused moe
at::Tensor fused_experts_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& topk_weights,
at::Tensor& topk_ids,
bool inplace,
bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni);
at::Tensor shared_expert_cpu(
at::Tensor& hidden_states,
at::Tensor& w1,
at::Tensor& w2,
at::Tensor& fused_experts_out,
double routed_scaling_factor,
bool inplace,
bool use_int8_w8a8,
std::optional<at::Tensor>& w1_scale,
std::optional<at::Tensor>& w2_scale,
std::optional<at::Tensor>& a1_scale,
std::optional<at::Tensor>& a2_scale,
bool is_vnni);
// weight absorption
std::tuple<at::Tensor, at::Tensor, at::Tensor> qkv_proj_with_rope(
at::Tensor& hidden_states,
at::Tensor& q_a_proj_weight,
at::Tensor& q_b_proj_weight,
at::Tensor& kv_a_proj_weight,
at::Tensor& w_kc,
at::Tensor& q_a_layernorm_weight,
at::Tensor& kv_a_layernorm_weight,
at::Tensor& positions,
at::Tensor& cos_sin_cache,
double eps,
bool use_int8_w8a8,
std::optional<at::Tensor>& q_a_proj_scale,
std::optional<at::Tensor>& q_b_proj_scale,
std::optional<at::Tensor>& kv_a_proj_scale,
bool is_vnni);
// shared memory init
void initialize(int size, int rank);
// shared memory all_reduce
void shm_allreduce(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op);
// shared memory all_gather
at::Tensor shm_allgather(at::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim);
// rope
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos);
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// activation
m.def("silu_and_mul_cpu", &silu_and_mul_cpu, "SiLU and mul for CPU");
// norm
m.def("rmsnorm_cpu", &rmsnorm_cpu, "Root mean square normalization for CPU");
m.def("fused_add_rmsnorm_cpu", &fused_add_rmsnorm_cpu, "Fused add root mean square normalization for CPU");
// topk
m.def("grouped_topk_cpu", &grouped_topk_cpu, "Grouped TopK for CPU");
// biased group topk
m.def("biased_grouped_topk_cpu", &biased_grouped_topk_cpu, "Biased Grouped TopK for CPU");
// decode
m.def("decode_attention_cpu", &decode_attention_cpu, "Attention decoding for CPU");
// extend
m.def("extend_attention_cpu", &extend_attention_cpu, "Attention extend for CPU");
// weight prepack
m.def("convert_weight_packed", &convert_weight_packed, "prepack weight to vnni format for intel AMX");
// quant
m.def("per_token_quant_int8_cpu", &per_token_quant_int8_cpu, "dynamic quantization for CPU");
// gemm
m.def("weight_packed_linear", &weight_packed_linear, "weight packed linear for intel AMX");
// igemm
m.def("int8_scaled_mm_cpu", &int8_scaled_mm_cpu, "int8 weight packed linear for intel AMX");
// quant + igemm
m.def(
"int8_scaled_mm_with_quant", &int8_scaled_mm_with_quant, "fused per row quant and int8 scaled mm for intel AMX");
// bmm
m.def("bmm_cpu", &bmm_cpu, "bmm kernel for intel AMX");
// moe
m.def("fused_experts_cpu", &fused_experts_cpu, "fused moe kernel for CPU");
// weight absorption
m.def("qkv_proj_with_rope", &qkv_proj_with_rope, "fused qkv projection kernel with weight absorption for intel AMX");
// shared expert
m.def("shared_expert_cpu", &shared_expert_cpu, "shared expert kernel for CPU");
// all reduce
m.def("initialize", &initialize, "shared memory initialization for CPU");
m.def("shm_allreduce", &shm_allreduce, "low latency all_reduce implementation for CPU");
m.def("shm_allgather", &shm_allgather, "low latency all_gather implementation for CPU");
// rope
m.def("rotary_position_embedding_cpu", &rotary_position_embedding_cpu, "rotary position embedding for CPU");
}
#pragma once
#if defined(__AVX512F__) && defined(__AVX512BF16__) && defined(__AMX_BF16__)
#define CPU_CAPABILITY_AVX512
#endif
#include <ATen/cpu/vec/functional.h>
#include <ATen/cpu/vec/vec.h>
namespace {
using namespace at::vec;
template <typename scalar_t, typename std::enable_if_t<is_reduced_floating_point_v<scalar_t>, int> = 0>
inline Vectorized<scalar_t> convert_from_float_ext(const Vectorized<float>& a, const Vectorized<float>& b) {
return at::vec::convert_from_float<scalar_t>(a, b);
}
#if defined(CPU_CAPABILITY_AVX512)
// `at::vec::convert_from_float<>` from PyTorch doesn't have avx512-bf16 intrinsics;
// use the native instruction for the float32->bfloat16 conversion here
template <>
inline Vectorized<at::BFloat16>
convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorized<float>& b) {
return (__m512i)(_mm512_cvtne2ps_pbh(__m512(b), __m512(a)));
}
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)
#endif
// vector to scalar reduction
#if defined(CPU_CAPABILITY_AVX512) && 0
inline float vec_reduce_sum(const Vectorized<float>& a) {
return _mm512_reduce_add_ps(__m512(a));
}
inline float vec_reduce_max(const Vectorized<float>& a) {
return _mm512_reduce_max_ps(__m512(a));
}
#else
inline float vec_reduce_sum(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return x + y; }, a);
}
inline float vec_reduce_max(const Vectorized<float>& a) {
return vec_reduce_all([](Vectorized<float>& x, Vectorized<float>& y) { return maximum(x, y); }, a);
}
#endif
// https://github.com/InternLM/lmdeploy/blob/086481ed84b59bee3b8e4274e5fc69620040c048/lmdeploy/pytorch/kernels/cuda/w8a8_triton_kernels.py#L282
template <typename scalar_t>
inline void
quantize_row_int8(uint8_t* __restrict__ Aq, float& As, const scalar_t* __restrict__ A, int64_t K, float eps = 1e-7) {
float amax = 0.f; // absolute max
for (int64_t k = 0; k < K; ++k) {
const float val = static_cast<float>(A[k]);
amax = std::max(amax, std::abs(val));
}
amax = std::max(amax, eps);
const float scale = amax / 127;
const float inv_scale = 127 / amax;
for (int64_t k = 0; k < K; ++k) {
const float val = static_cast<float>(A[k]) * inv_scale;
    Aq[k] = static_cast<uint8_t>(static_cast<int32_t>(std::round(val)) + 128);
}
As = scale;
}
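// Hedged usage sketch (illustrative only; dequantize_int8_reference is a
// hypothetical helper): rows quantized above are stored as uint8 with a fixed
// zero point of 128 and a per-row scale As = amax / 127, so a value can be
// recovered approximately as (Aq[k] - 128) * As.
[[maybe_unused]] static inline float dequantize_int8_reference(uint8_t q, float scale) {
  return static_cast<float>(static_cast<int>(q) - 128) * scale;
}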
#if defined(CPU_CAPABILITY_AVX512)
template <>
inline void quantize_row_int8<at::BFloat16>(
uint8_t* __restrict__ Aq, float& As, const at::BFloat16* __restrict__ A, int64_t K, float eps) {
const __m512 signBit = _mm512_set1_ps(-0.0f);
const __m512i off = _mm512_set1_epi32(128);
// K is 32x, no remainder
float amax = 0.f;
__m512 vamax0 = _mm512_set1_ps(0.f);
__m512 vamax1 = _mm512_set1_ps(0.f);
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
vamax0 = _mm512_max_ps(vamax0, _mm512_andnot_ps(signBit, va0));
vamax1 = _mm512_max_ps(vamax1, _mm512_andnot_ps(signBit, va1));
}
amax = _mm512_reduce_max_ps(_mm512_max_ps(vamax0, vamax1));
amax = std::max(amax, eps);
const float scale = amax / 127;
const float inv_scale = 127 / amax;
const __m512 vd = _mm512_set1_ps(inv_scale);
for (int64_t k = 0; k < K; k += 32) {
__m512i va = _mm512_loadu_si512((void*)(A + k));
__m512 va0 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 0));
__m512 va1 = CVT_BF16_TO_FP32(_mm512_extracti32x8_epi32(va, 1));
va0 = _mm512_mul_ps(va0, vd);
va1 = _mm512_mul_ps(va1, vd);
va0 = _mm512_roundscale_ps(va0, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
va1 = _mm512_roundscale_ps(va1, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC));
__m128i i0 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va0), off));
__m128i i1 = _mm512_cvtepi32_epi8(_mm512_add_epi32(_mm512_cvtps_epi32(va1), off));
_mm256_storeu_si256(reinterpret_cast<__m256i*>(Aq + k), _mm256_set_m128i(i1, i0));
}
As = scale;
}
#endif
} // anonymous namespace
# Copyright 2025 SGLang Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
import shutil
import sys
from pathlib import Path
import torch
from setuptools import find_packages, setup
from setuptools.command.build_py import build_py
from torch.utils.cpp_extension import BuildExtension, CppExtension
root = Path(__file__).parent.resolve()
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
def _get_version():
with open(root / "pyproject.toml") as f:
for line in f:
if line.startswith("version"):
return line.split("=")[1].strip().strip('"')
operator_namespace = "sgl_kernel"
include_dirs = []
sources = [
"csrc/cpu/activation.cpp",
"csrc/cpu/bmm.cpp",
"csrc/cpu/decode.cpp",
"csrc/cpu/extend.cpp",
"csrc/cpu/gemm.cpp",
"csrc/cpu/gemm_int8.cpp",
"csrc/cpu/moe.cpp",
"csrc/cpu/moe_int8.cpp",
"csrc/cpu/norm.cpp",
"csrc/cpu/qkv_proj.cpp",
"csrc/cpu/topk.cpp",
"csrc/cpu/interface.cpp",
"csrc/cpu/shm.cpp",
"csrc/cpu/torch_extension_cpu.cpp",
]
extra_compile_args = {
"cxx": [
"-O3",
"-Wno-unknown-pragmas",
"-march=native",
"-fopenmp",
]
}
libraries = ["c10", "torch", "torch_python"]
cmdclass = {
"build_ext": BuildExtension.with_options(use_ninja=True),
}
Extension = CppExtension
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]
ext_modules = [
Extension(
name="sgl_kernel.common_ops",
sources=sources,
include_dirs=include_dirs,
extra_compile_args=extra_compile_args,
libraries=libraries,
extra_link_args=extra_link_args,
py_limited_api=True,
),
]
setup(
name="sgl-kernel",
version=_get_version(),
packages=find_packages(where="python"),
package_dir={"": "python"},
ext_modules=ext_modules,
cmdclass=cmdclass,
options={"bdist_wheel": {"py_limited_api": "cp39"}},
)