Unverified commit a73c4df4 authored by Ma Mingfei, committed by GitHub

Add optimized native kernels in sgl-kernel (#5150)


Co-authored-by: Chunyuan WU <chunyuan.wu@intel.com>
Co-authored-by: YanbingJiang <yanbing.jiang@intel.com>
Co-authored-by: blzheng <beilei.zheng@intel.com>
parent 89a55418
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t, typename func_t, typename vec_func_t>
void act_and_mul_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
int64_t num_tokens,
int64_t dim,
const func_t& f,
const vec_func_t& vf) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int64_t kVecSize = bVec::size();
at::parallel_for(0, num_tokens, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
const scalar_t* __restrict__ input_ptr = input + i * 2 * dim;
const scalar_t* __restrict__ input_other_ptr = input_ptr + dim;
scalar_t* __restrict__ output_ptr = output + i * dim;
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= dim - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec y_bvec = bVec::loadu(input_other_ptr + d);
fVec y_fvec0, y_fvec1;
std::tie(y_fvec0, y_fvec1) = at::vec::convert_to_float(y_bvec);
x_fvec0 = vf(x_fvec0);
x_fvec1 = vf(x_fvec1);
x_fvec0 = x_fvec0 * y_fvec0;
x_fvec1 = x_fvec1 * y_fvec1;
x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
x_bvec.store(output_ptr + d);
}
#pragma GCC unroll 4
for (; d < dim; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float y_val = static_cast<float>(input_other_ptr[d]);
output_ptr[d] = f(x_val) * y_val;
}
}
});
}
} // anonymous namespace
// input : {num_tokens, 2 * d}
// output : {num_tokens, d}
at::Tensor silu_and_mul_cpu(at::Tensor& input) {
RECORD_FUNCTION("sgl-kernel::silu_and_mul_cpu", std::vector<c10::IValue>({input}));
auto sizes = input.sizes().vec();
int64_t last_dim = input.ndimension() - 1;
int64_t d = sizes[last_dim] / 2;
sizes[last_dim] = d;
int64_t num_tokens = input.numel() / input.size(-1);
at::Tensor out = at::empty(sizes, input.options());
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "silu_and_mul", [&] {
using Vec = at::vec::Vectorized<float>;
act_and_mul_kernel_impl(
out.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
num_tokens,
d,
[](float x) { return x / (1.f + std::exp(-x)); },
[](Vec x) { return x / (Vec(1.f) + x.neg().exp()); });
});
return out;
}
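// Illustrative usage (a minimal sketch, not part of this diff; assumes the op is called
// directly from C++ with ATen available):
//
//   at::Tensor x = at::randn({8, 2 * 128}, at::dtype(at::kBFloat16));  // {num_tokens, 2 * d}
//   at::Tensor y = silu_and_mul_cpu(x);                                // {num_tokens, d}
//   // y[i][j] == silu(x[i][j]) * x[i][j + 128]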
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t>
void bmm_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
int64_t B,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideB,
int64_t mat1_strideM,
int64_t out_strideB,
int64_t out_strideM,
float scale = 0.f) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// mat2 contiguous in [B, N, K]
int64_t mat2_strideB = N * K;
int64_t mat2_strideN = K;
const bool use_brgemm = can_use_brgemm<scalar_t>(M);
// parallel on [B, MB, NB]
at::parallel_for(0, B * MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, mb{0}, nb{0};
data_index_init(begin, bs, B, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t>(
/* A */ mat1 + bs * mat1_strideB + mb_start * mat1_strideM,
/* B */ mat2 + bs * mat2_strideB + nb_start * mat2_strideN /* nb * BLOCK_N * K */,
/* C */ out + bs * out_strideB + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
// move to the next index
data_index_step(bs, B, mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
}
} // anonymous namespace
// mat1 : [B, M, K]
// mat2 : [B, N, K] or [B, OC, IC]
// out : [B, M, N]
// scale: [] 0-dim tensor for per tensor quant
//
void bmm_cpu(at::Tensor& out, at::Tensor& mat1, at::Tensor& mat2, bool is_vnni, std::optional<at::Tensor>& scale) {
RECORD_FUNCTION("sgl-kernel::bmm_cpu", std::vector<c10::IValue>({out, mat1, mat2}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
// input and out could be non-contiguous
// weight needs to be contiguous in [OC, IC] order
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(out);
CHECK_INPUT(mat2);
CHECK_DIM(3, out);
CHECK_DIM(3, mat1);
CHECK_DIM(3, mat2);
int64_t B = mat1.size(0);
int64_t M = mat1.size(1);
int64_t N = mat2.size(1);
int64_t K = mat1.size(2);
TORCH_CHECK(!scale.has_value(), "bmm: fp8 weight is not supported yet.");
TORCH_CHECK(N % 32 == 0, "tinygemm requires N to be 32x.");
int64_t mat1_strideB = mat1.stride(0);
int64_t mat1_strideM = mat1.stride(1);
int64_t out_strideB = out.stride(0);
int64_t out_strideM = out.stride(1);
// check shapes
TORCH_CHECK(mat2.size(0) == B && mat2.size(2) == K, "bmm: mat2 shape mismatch!");
TORCH_CHECK(out.size(0) == B && out.size(1) == M, "bmm: out shape mismatch!");
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "bmm_kernel_impl", [&] {
bmm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
B,
M,
N,
K,
mat1_strideB,
mat1_strideM,
out_strideB,
out_strideM);
});
}
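// Illustrative usage (a minimal sketch, not part of this diff; assumes mat2 is passed in
// plain [B, N, K] layout so the kernel packs it to VNNI on the fly):
//
//   at::Tensor a   = at::randn({4, 16, 256}, at::dtype(at::kBFloat16));  // [B, M, K]
//   at::Tensor b   = at::randn({4, 64, 256}, at::dtype(at::kBFloat16));  // [B, N, K], N % 32 == 0
//   at::Tensor out = at::empty({4, 16, 64}, a.options());                // [B, M, N]
//   std::optional<at::Tensor> scale;                                     // fp8 scale not supported yet
//   bmm_cpu(out, a, b, /*is_vnni=*/false, scale);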
#pragma once
#include <ATen/ATen.h>
#include <ATen/Parallel.h>
#include <ATen/record_function.h>
#if defined(_OPENMP)
#include <omp.h>
#endif
namespace {
// dispatch bool
#define AT_DISPATCH_BOOL(BOOL_V, BOOL_NAME, ...) \
[&] { \
if (BOOL_V) { \
constexpr bool BOOL_NAME = true; \
return __VA_ARGS__(); \
} else { \
constexpr bool BOOL_NAME = false; \
return __VA_ARGS__(); \
} \
}()
// dispatch: bfloat16, float16, int8_t
#define CPU_DISPATCH_PACKED_TYPES(TYPE, ...) \
[&] { \
switch (TYPE) { \
case at::ScalarType::BFloat16: { \
using packed_t = at::BFloat16; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Half: { \
using packed_t = at::Half; \
return __VA_ARGS__(); \
} \
case at::ScalarType::Char: { \
using packed_t = int8_t; \
return __VA_ARGS__(); \
} \
default: \
TORCH_CHECK(false, "Unsupported floating data type.\n"); \
} \
}()
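// Illustrative use of the two dispatch helpers above (sketch only, mirroring how the
// kernels below call them):
//
//   AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
//     // `has_bias` is a constexpr bool inside this lambda, usable as a template argument.
//   });
//   CPU_DISPATCH_PACKED_TYPES(weight.scalar_type(), [&] {
//     // `packed_t` is at::BFloat16, at::Half or int8_t depending on the runtime dtype.
//   });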
#define UNUSED(x) (void)(x)
#define CHECK_CPU(x) TORCH_CHECK(x.device().type() == at::kCPU, #x " must be a CPU tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x " must be contiguous at the last dimension")
#define CHECK_INPUT(x) \
CHECK_CPU(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
CHECK_CPU(x); \
CHECK_LAST_DIM_CONTIGUOUS(x)
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
// parallel routines
constexpr int GRAIN_SIZE = 1024;
template <typename T, typename std::enable_if<std::is_integral<T>::value, int>::type = 0>
inline T div_up(T x, T y) {
return (x + y - 1) / y;
}
template <typename T>
inline void balance211(T n, T nth, T ith, T& n_start, T& n_end) {
#if 0
// onednn partition pattern
T& n_my = n_end;
if (nth <= 1 || n == 0) {
n_start = 0;
n_my = n;
} else {
T n1 = div_up(n, nth);
T n2 = n1 - 1;
T T1 = n - n2 * nth;
n_my = ith < T1 ? n1 : n2;
n_start = ith <= T1 ? ith*n1 : T1 * n1 + (ith - T1) * n2;
}
n_end += n_start;
#else
// pytorch aten partition pattern
T n_my = div_up(n, nth);
n_start = ith * n_my;
n_end = std::min(n_start + n_my, n);
#endif
}
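// Worked example for the aten-style partition above (illustrative only): n = 10 elements
// over nth = 4 threads gives chunks of div_up(10, 4) = 3, so
//   ith = 0 -> [0, 3), ith = 1 -> [3, 6), ith = 2 -> [6, 9), ith = 3 -> [9, 10).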
template <typename func_t>
inline void parallel_for(int n, const func_t& f) {
#if defined(_OPENMP)
#pragma omp parallel
{
int nth = omp_get_num_threads();
int ith = omp_get_thread_num();
int tbegin, tend;
balance211(n, nth, ith, tbegin, tend);
f(tbegin, tend);
}
#else
f(0, n);
#endif
}
// data indexing for dimension collapse
template <typename T>
inline T data_index_init(T offset) {
return offset;
}
template <typename T, typename... Args>
inline T data_index_init(T offset, T& x, const T& X, Args&&... args) {
offset = data_index_init(offset, std::forward<Args>(args)...);
x = offset % X;
return offset / X;
}
inline bool data_index_step() {
return true;
}
template <typename T, typename... Args>
inline bool data_index_step(T& x, const T& X, Args&&... args) {
if (data_index_step(std::forward<Args>(args)...)) {
x = ((x + 1) == X) ? 0 : (x + 1);
return x == 0;
}
return false;
}
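// Illustrative use of the data_index helpers for a collapsed [MB, NB] loop
// (the same pattern the kernels below use inside at::parallel_for; sketch only):
//
//   int64_t mb{0}, nb{0};
//   data_index_init(begin, mb, MB, nb, NB);   // decode `begin` into (mb, nb)
//   for (int64_t i = begin; i < end; ++i) {
//     // ... work on block (mb, nb) ...
//     data_index_step(mb, MB, nb, NB);        // advance nb first, carry into mb
//   }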
// forced unroll for perf critical path
#if __has_attribute(always_inline)
#define ALWAYS_INLINE __attribute__((__always_inline__)) inline
#else
#define ALWAYS_INLINE inline
#endif
template <int n>
struct Unroll {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
Unroll<n - 1>{}(f, args...);
f(std::integral_constant<int, n - 1>{}, args...);
}
};
template <>
struct Unroll<1> {
template <typename Func, typename... Args>
ALWAYS_INLINE void operator()(const Func& f, Args... args) const {
f(std::integral_constant<int, 0>{}, args...);
}
};
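// Illustrative use of Unroll (sketch): invokes f with std::integral_constant<int, 0..n-1>,
// so the index is usable as a constant expression inside the lambda:
//
//   float acc[4];
//   Unroll<4>{}([&](auto i) {
//     constexpr int idx = i;  // 0, 1, 2, 3, resolved at compile time
//     acc[idx] = 0.f;
//   });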
} // anonymous namespace
#include "gemm.h"
#include "common.h"
#include "vec.h"
namespace {
// packed layout:
// quants {N, K} int8_t
// comp {N} int32_t
template <int BLOCK_N>
inline void s8s8_compensation(int8_t* __restrict__ packed, int K) {
#if defined(CPU_CAPABILITY_AVX512)
constexpr int COLS = BLOCK_N / 16;
__m512i vcomp[COLS];
for (int col = 0; col < COLS; ++col) {
vcomp[col] = _mm512_setzero_si512();
}
const int64_t offset = BLOCK_N * K;
const __m512i off = _mm512_set1_epi8(static_cast<char>(0x80));
for (int k = 0; k < K / 4; ++k) {
for (int col = 0; col < COLS; ++col) {
__m512i vb = _mm512_loadu_si512((const __m512i*)(packed + k * BLOCK_N * 4 + col * 64));
vcomp[col] = _mm512_dpbusd_epi32(vcomp[col], off, vb);
}
}
for (int col = 0; col < COLS; ++col) {
_mm512_storeu_si512((__m512i*)(packed + offset + col * 64), vcomp[col]);
}
#else
TORCH_CHECK(false, "s8s8_compensation not implemented!");
#endif
}
// convert to vnni format
// from [N, K] to [K/2, N, 2] for bfloat16 and float16
template <typename packed_t>
inline void pack_vnni(packed_t* __restrict__ packed, const packed_t* __restrict__ weight, int N, int K) {
const int VNNI_BLK = 2;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
}
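// Worked example of the VNNI-2 remapping above (illustrative): for N = 2, K = 4 the
// [N, K] weight
//   w = [[a0 a1 a2 a3],
//        [b0 b1 b2 b3]]
// is stored as [K/2, N, 2] = [[a0 a1, b0 b1], [a2 a3, b2 b3]], i.e. each row keeps pairs
// of adjacent K-elements together so AMX / avx512-bf16 dot products can consume them directly.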
template <>
inline void pack_vnni<int8_t>(int8_t* __restrict__ packed, const int8_t* __restrict__ weight, int N, int K) {
constexpr int BLOCK_N = block_size_n();
TORCH_CHECK(N == BLOCK_N);
const int VNNI_BLK = 4;
for (int n = 0; n < N; ++n) {
for (int k = 0; k < K / VNNI_BLK; ++k) {
for (int d = 0; d < VNNI_BLK; ++d) {
packed[k * N * VNNI_BLK + n * VNNI_BLK + d] = weight[n * K + k * VNNI_BLK + d];
}
}
}
s8s8_compensation<BLOCK_N>(packed, K);
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ input, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d);
fVec data1 = fVec::loadu(input + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d]);
}
}
template <typename scalar_t>
inline void copy_add_stub(
scalar_t* __restrict__ out, const float* __restrict__ input, const float* __restrict__ bias, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= size - kVecSize; d += kVecSize) {
fVec data0 = fVec::loadu(input + d) + fVec::loadu(bias + d);
fVec data1 = fVec::loadu(input + d + fVec::size()) + fVec::loadu(bias + d + fVec::size());
bVec out_vec = convert_from_float_ext<scalar_t>(data0, data1);
out_vec.store(out + d);
}
for (; d < size; ++d) {
out[d] = static_cast<scalar_t>(input[d] + bias[d]);
}
}
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const at::BFloat16* __restrict__ A,
const at::BFloat16* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512bh va;
__m512bh vb[COLS];
__m512 vc[ROWS * COLS];
auto loadc = [&](auto i) {
constexpr int col = i % COLS;
if constexpr (has_bias) {
vc[i] = _mm512_loadu_ps(bias + col * 16);
} else {
vc[i] = _mm512_set1_ps(0.f);
}
};
Unroll<ROWS * COLS>{}(loadc);
const int64_t K2 = K >> 1;
const int64_t lda2 = lda >> 1;
const int64_t ldb2 = ldb; // ldb * 2 >> 1;
const float* a_ptr = reinterpret_cast<const float*>(A);
const float* b_ptr = reinterpret_cast<const float*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = (__m512bh)(_mm512_set1_ps(a_ptr[row * lda2 + k]));
}
if constexpr (row == 0) {
vb[col] = (__m512bh)(_mm512_loadu_si512(b_ptr + k * ldb2 + col * 16));
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb2 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbf16_ps(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K2; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// for COLS = 2, 4 use 512bit store
// for COLS = 1, 3 use 256bit store
if constexpr (COLS % 2 == 0) {
if constexpr (col % 2 == 0) {
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)),
(__m512i)(_mm512_cvtne2ps_pbh(vc[row * COLS + col + 1], vc[row * COLS + col])));
}
} else {
_mm256_storeu_si256(reinterpret_cast<__m256i*>(C + row * ldc + col * 16), (__m256i)(_mm512_cvtneps_pbh(vc[i])));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 2, \
C + mb_start * ldc + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
struct brgemm {
static inline void apply(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int BLOCK_N = block_size_n();
at::native::cpublas::brgemm(M, N, K, lda, ldb, BLOCK_N, /* add_C */ false, A, B, Ctmp);
// copy from Ctmp to C
for (int64_t m = 0; m < M; ++m) {
if constexpr (has_bias) {
copy_add_stub(C + m * ldc, Ctmp + m * BLOCK_N, bias, N);
} else {
copy_stub(C + m * ldc, Ctmp + m * BLOCK_N, N);
}
}
}
};
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
if (brg) {
brgemm<scalar_t, has_bias>::apply(A, B, C, Ctmp, bias, M, N, K, lda, ldb, ldc);
return;
}
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
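// dispatch key: high nibble = mb_size (1..4), low nibble = nb_size >> 4 (2 for 32, 4 for 64),
// e.g. 0x34 selects the 3x64 micro-kernel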
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void weight_packed_linear_kernel_impl(
scalar_t* __restrict__ out,
const scalar_t* __restrict__ mat1,
const scalar_t* __restrict__ mat2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t mat1_strideM,
int64_t out_strideM) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// use avx512-bf16 when a) M is small and b) dtype is bfloat16; otherwise use AMX brgemm
const bool use_brgemm = (M > 4) || (!std::is_same_v<scalar_t, at::BFloat16>);
// parallel on [MB, NB]
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use float32 for accumulate
alignas(64) float Ctmp[BLOCK_M * BLOCK_N];
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(M - mb_start, BLOCK_M);
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * mat1_strideM,
/* B */ mat2 + nb_start * K /* nb * BLOCK_N * K */,
/* C */ out + mb_start * out_strideM + nb_start,
/* Ctmp*/ Ctmp,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ mat1_strideM,
/* ldb */ nb_size,
/* ldc */ out_strideM,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const TYPE* __restrict__ A, \
const TYPE* __restrict__ B, \
TYPE* __restrict__ C, \
float* __restrict__ Ctmp, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
at::Tensor convert_weight_packed(at::Tensor& weight) {
// for 3d moe weights
// weight : [E, OC, IC]
// w1 : [E, 2N, K]
// w2 : [E, K, N]
CHECK_INPUT(weight);
const int64_t ndim = weight.ndimension();
TORCH_CHECK(ndim == 2 || ndim == 3, "expect weight to be 2d or 3d, got ", ndim, "d tensor.");
const auto st = weight.scalar_type();
const int64_t E = ndim == 3 ? weight.size(0) : 1;
const int64_t OC = ndim == 3 ? weight.size(1) : weight.size(0);
const int64_t IC = ndim == 3 ? weight.size(2) : weight.size(1);
// we handle 2 TILE_N at a time.
TORCH_CHECK(OC % TILE_N == 0, "invalid weight out features ", OC);
TORCH_CHECK(IC % TILE_K == 0, "invalid weight input features ", IC);
constexpr int64_t BLOCK_N = block_size_n();
const int64_t NB = div_up(OC, BLOCK_N);
// use phony sizes here [E, OC, IC], for each [E], [OC, IC] -> [IC / 2, OC, 2]
auto packed_weight = at::empty({}, weight.options());
const int64_t stride = OC * IC;
TORCH_CHECK(
st == at::kBFloat16 || st == at::kHalf || st == at::kChar, "expect weight to be bfloat16, float16 or int8.");
CPU_DISPATCH_PACKED_TYPES(st, [&] {
// adjust most inner dimension size
const int packed_row_size = get_row_size<packed_t>(IC);
auto sizes = weight.sizes().vec();
sizes[ndim - 1] = packed_row_size;
packed_weight.resize_(sizes);
const packed_t* w_data = weight.data_ptr<packed_t>();
packed_t* packed_data = packed_weight.data_ptr<packed_t>();
// parallel on {E, NB}
at::parallel_for(0, E * NB, 0, [&](int64_t begin, int64_t end) {
int64_t e{0}, nb{0};
data_index_init(begin, e, E, nb, NB);
for (int64_t i = begin; i < end; ++i) {
UNUSED(i);
int64_t n = nb * BLOCK_N;
int64_t n_size = std::min(BLOCK_N, OC - n);
pack_vnni<packed_t>(
packed_data + e * OC * packed_row_size + n * packed_row_size, w_data + e * stride + n * IC, n_size, IC);
// move to the next index
data_index_step(e, E, nb, NB);
}
});
});
return packed_weight;
}
// mat1 : [M, K]
// mat2 : [N, K]
// bias : [N]
// out : [M, N]
//
at::Tensor weight_packed_linear(at::Tensor& mat1, at::Tensor& mat2, std::optional<at::Tensor>& bias, bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::weight_packed_linear", std::vector<c10::IValue>({mat1, mat2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat2.size(1);
CHECK_EQ(mat1.size(1), K);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
auto out = at::empty({M, N}, mat1.options());
// strides
int64_t mat1_strideM = mat1.stride(0);
int64_t out_strideM = out.stride(0);
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(mat1.scalar_type(), "weight_packed_linear_kernel_impl", [&] {
weight_packed_linear_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<scalar_t>(),
packed_w.data_ptr<scalar_t>(),
bias_data,
M,
N,
K,
mat1_strideM,
out_strideM);
});
return out;
}
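// Illustrative usage (a minimal sketch, not part of this diff): pre-pack the weight once,
// then reuse it for every GEMM call.
//
//   at::Tensor w = at::randn({256, 512}, at::dtype(at::kBFloat16));   // [N, K]
//   at::Tensor packed = convert_weight_packed(w);                     // VNNI layout
//   at::Tensor x = at::randn({32, 512}, at::dtype(at::kBFloat16));    // [M, K]
//   std::optional<at::Tensor> bias;                                   // bias, if present, is float32 [N]
//   at::Tensor y = weight_packed_linear(x, packed, bias, /*is_vnni=*/true);  // [M, N]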
#pragma once
#include <ATen/native/CPUBlas.h>
// amx-bf16
#define TILE_M 16
#define TILE_N 16
#define TILE_K 32
// block size for AMX gemm
constexpr int block_size_m() {
return 2 * TILE_M;
}
constexpr int block_size_n() {
return 2 * TILE_N;
}
// thresholds for using brgemm (Intel AMX)
template <typename T>
inline bool can_use_brgemm(int M);
template <>
inline bool can_use_brgemm<at::BFloat16>(int M) {
return M > 4;
}
template <>
inline bool can_use_brgemm<at::Half>(int M) {
return true;
}
// TODO: add u8s8 brgemm, this requires PyTorch 2.7
template <>
inline bool can_use_brgemm<int8_t>(int M) {
return false;
}
// work around compiler internal error
#define BLOCK_K 128 // 4 * TILE_K
// adjust leading dimension size for K
template <typename T>
inline int64_t get_row_size(int64_t K) {
return K;
}
template <>
inline int64_t get_row_size<int8_t>(int64_t K) {
return K + sizeof(int32_t);
}
inline int64_t get_row_size(int64_t K, bool use_int8_w8a8) {
return use_int8_w8a8 ? K + sizeof(int32_t) : K;
}
// pack weight to vnni format
at::Tensor convert_weight_packed(at::Tensor& weight);
// moe implementations for int8 w8a8
template <typename scalar_t>
void fused_experts_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
scalar_t* __restrict__ ic2,
uint8_t* __restrict__ A_tmp,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const float* __restrict__ topk_weights,
const int32_t* __restrict__ sorted_ids,
const int32_t* __restrict__ expert_ids,
const int32_t* __restrict__ offsets,
int64_t M,
int64_t N,
int64_t K,
int64_t E,
int64_t topk,
int64_t num_tokens_post_pad);
// shared expert implementation for int8 w8a8
template <typename scalar_t>
void shared_expert_int8_kernel_impl(
scalar_t* __restrict__ output,
scalar_t* __restrict__ ic1,
float* __restrict__ C_tmp,
uint8_t* __restrict__ Aq_tmp,
float* __restrict__ As_tmp,
const scalar_t* __restrict__ input,
const int8_t* __restrict__ packed_w1,
const int8_t* __restrict__ packed_w2,
const float* __restrict__ w1s,
const float* __restrict__ w2s,
const scalar_t* __restrict__ fused_experts_out,
float routed_scaling_factor,
int64_t M,
int64_t N,
int64_t K);
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const scalar_t* __restrict__ A,
const scalar_t* __restrict__ B,
scalar_t* __restrict__ C,
float* __restrict__ Ctmp,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg);
#include "common.h"
#include "gemm.h"
#include "vec.h"
namespace {
template <typename scalar_t, bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
TORCH_CHECK(false, "tinygemm_kernel_nn: scalar path not implemented!");
}
};
#if defined(CPU_CAPABILITY_AVX512)
template <bool has_bias, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::BFloat16, has_bias, BLOCK_M, BLOCK_N> {
static inline void apply(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
at::BFloat16* __restrict__ C,
const float* __restrict__ As,
const float* __restrict__ Bs,
const int32_t* __restrict__ Bcomp,
const float* __restrict__ bias,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc) {
constexpr int ROWS = BLOCK_M;
constexpr int COLS = BLOCK_N / 16;
static_assert(COLS % 2 == 0);
// prefetch distance
constexpr int PREFETCH_SIZE_K = 0;
__m512i va;
__m512i vb[COLS];
__m512i vc[ROWS * COLS];
__m512i vcomp[COLS];
__m512 vd0;
__m512 vd1[COLS];
// note: a 4x4 tile would spill registers, but luckily we use 4x2
__m512 vbias[COLS];
// [NOTE]: s8s8 igemm compensation in avx512-vnni
//
// avx512-vnni has no s8s8, so we need to change s8s8 to u8s8 with compensate:
//
// a * b = (a + 128) * b - 128 * b
// s s u s u s
//
// 1) 128 * b is pre-computed when packing B to vnni formats
// 2) a + 128 is fused when dynamically quantize A
//
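// Worked through in scalars (illustrative), for one output element:
//   acc_u8s8 = sum_k (a[k] + 128) * b[k]          // computed by _mm512_dpbusd_epi32
//   comp     = sum_k 128 * b[k]                   // precomputed in s8s8_compensation
//   C        = (acc_u8s8 - comp) * As[m] * Bs[n]  (+ bias[n] if present)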
auto loadc = [&](auto i) { vc[i] = _mm512_set1_epi32(0); };
Unroll<ROWS * COLS>{}(loadc);
const int64_t K4 = K >> 2;
const int64_t lda4 = lda >> 2;
const int64_t ldb4 = ldb; // ldb * 4 >> 2;
const int32_t* a_ptr = reinterpret_cast<const int32_t*>(A);
const int32_t* b_ptr = reinterpret_cast<const int32_t*>(B);
auto compute = [&](auto i, int64_t k) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
if constexpr (col == 0) {
va = _mm512_set1_epi32(a_ptr[row * lda4 + k]);
}
if constexpr (row == 0) {
vb[col] = _mm512_loadu_si512(b_ptr + k * ldb4 + col * 16);
if constexpr (PREFETCH_SIZE_K > 0) {
_mm_prefetch(b_ptr + (k + PREFETCH_SIZE_K) * ldb4 + col * 16, _MM_HINT_T0);
}
}
vc[i] = _mm512_dpbusd_epi32(vc[i], va, vb[col]);
};
for (int64_t k = 0; k < K4; ++k) {
Unroll<ROWS * COLS>{}(compute, k);
}
auto storec = [&](auto i) {
constexpr int row = i / COLS;
constexpr int col = i % COLS;
// load a scale
if constexpr (col == 0) {
vd0 = _mm512_set1_ps(As[row]);
}
// load b scale and vcomp per 2 vectors
// also load bias if any
if constexpr (row == 0) {
if constexpr (col % 2 == 0) {
vd1[col + 0] = _mm512_loadu_ps(Bs + col * 16);
vd1[col + 1] = _mm512_loadu_ps(Bs + col * 16 + 16);
vcomp[col + 0] = _mm512_loadu_si512(Bcomp + col * 16);
vcomp[col + 1] = _mm512_loadu_si512(Bcomp + col * 16 + 16);
if constexpr (has_bias) {
vbias[col + 0] = _mm512_loadu_ps(bias + col * 16);
vbias[col + 1] = _mm512_loadu_ps(bias + col * 16 + 16);
}
}
}
// for COLS = 2, 4 use 512bit store
if constexpr (col % 2 == 0) {
__m512 vc0 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 0], vcomp[col + 0]));
__m512 vc1 = _mm512_cvtepi32_ps(_mm512_sub_epi32(vc[row * COLS + col + 1], vcomp[col + 1]));
if constexpr (has_bias) {
vc0 = _mm512_fmadd_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0], vbias[col + 0]);
vc1 = _mm512_fmadd_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1], vbias[col + 1]);
} else {
vc0 = _mm512_mul_ps(_mm512_mul_ps(vc0, vd0), vd1[col + 0]);
vc1 = _mm512_mul_ps(_mm512_mul_ps(vc1, vd0), vd1[col + 1]);
}
_mm512_storeu_si512(
reinterpret_cast<__m512i*>((C + row * ldc + col * 16)), (__m512i)(_mm512_cvtne2ps_pbh(vc1, vc0)));
}
};
Unroll<ROWS * COLS>{}(storec);
}
};
#endif
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
tinygemm_kernel_nn<scalar_t, has_bias, MB_SIZE, NB_SIZE>::apply( \
A + mb_start * lda, \
B + nb_start * 4, \
C + mb_start * ldc + nb_start, \
As + mb_start, \
Bs + nb_start, \
Bcomp + nb_start, \
has_bias ? bias + nb_start : nullptr, \
K, \
lda, \
ldb, \
ldc);
template <typename scalar_t, bool has_bias>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
// B compensation
const int32_t* Bcomp = reinterpret_cast<const int32_t*>(B + block_size_n() * K);
// pattern: 1-4-16
constexpr int64_t BLOCK_M = 4;
constexpr int64_t BLOCK_N = 64;
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
for (int64_t mb = 0; mb < MB; ++mb) {
int64_t mb_start = mb * BLOCK_M;
int64_t mb_size = std::min(BLOCK_M, M - mb_start);
for (int64_t nb = 0; nb < NB; ++nb) {
int64_t nb_start = nb * BLOCK_N;
int64_t nb_size = std::min(BLOCK_N, N - nb_start);
switch (mb_size << 4 | nb_size >> 4) {
// mb_size = 1
case 0x12:
LAUNCH_TINYGEMM_KERNEL_NN(1, 32);
break;
case 0x14:
LAUNCH_TINYGEMM_KERNEL_NN(1, 64);
break;
// mb_size = 2
case 0x22:
LAUNCH_TINYGEMM_KERNEL_NN(2, 32);
break;
case 0x24:
LAUNCH_TINYGEMM_KERNEL_NN(2, 64);
break;
// mb_size = 3
case 0x32:
LAUNCH_TINYGEMM_KERNEL_NN(3, 32);
break;
case 0x34:
LAUNCH_TINYGEMM_KERNEL_NN(3, 64);
break;
// mb_size = 4
case 0x42:
LAUNCH_TINYGEMM_KERNEL_NN(4, 32);
break;
case 0x44:
LAUNCH_TINYGEMM_KERNEL_NN(4, 64);
break;
default:
TORCH_CHECK(false, "Unexpected block size, ", mb_size, "x", "nb_size");
}
}
}
}
template <typename scalar_t>
void int8_scaled_mm_kernel_impl(
scalar_t* __restrict__ out,
const uint8_t* __restrict__ mat1,
const int8_t* __restrict__ mat2,
const float* __restrict__ scales1,
const float* __restrict__ scales2,
const float* __restrict__ bias,
int64_t M,
int64_t N,
int64_t K) {
constexpr int64_t BLOCK_M = block_size_m();
constexpr int64_t BLOCK_N = block_size_n();
const int64_t MB = div_up(M, BLOCK_M);
const int64_t NB = div_up(N, BLOCK_N);
// TODO: brgemm u8s8 depends on PyTorch 2.7 release.
const bool use_brgemm = false;
// K + 4 after compensation
const int64_t packed_row_size = get_row_size<int8_t>(K);
AT_DISPATCH_BOOL(bias != nullptr, has_bias, [&] {
at::parallel_for(0, MB * NB, 0, [&](int64_t begin, int64_t end) {
int64_t mb{0}, nb{0};
data_index_init(begin, mb, MB, nb, NB);
// for brgemm, use int32_t for accumulate
alignas(64) int32_t Ctmp[BLOCK_M * BLOCK_N];
for (int i = begin; i < end; ++i) {
UNUSED(i);
int mb_start = mb * BLOCK_M;
int mb_size = std::min(M - mb_start, BLOCK_M);
int nb_start = nb * BLOCK_N;
int nb_size = std::min(N - nb_start, BLOCK_N);
tinygemm_kernel<scalar_t, has_bias>(
/* A */ mat1 + mb_start * K,
/* B */ mat2 + nb_start * packed_row_size /* nb * BLOCK_N * (K + 4) */,
/* C */ out + mb_start * N + nb_start,
/* Ctmp*/ Ctmp,
/* As */ scales1 + mb_start,
/* Bs */ scales2 + nb_start,
/* bias*/ bias + nb_start,
/* M */ mb_size,
/* N */ nb_size,
/* K */ K,
/* lda */ K,
/* ldb */ nb_size,
/* ldc */ N,
/* brg */ use_brgemm);
// move to the next index
data_index_step(mb, MB, nb, NB);
}
if (use_brgemm) {
at::native::cpublas::brgemm_release();
}
});
});
}
} // anonymous namespace
// tinygemm interface
template <typename scalar_t>
void tinygemm_kernel(
const uint8_t* __restrict__ A,
const int8_t* __restrict__ B,
scalar_t* __restrict__ C,
int32_t* __restrict__ Ctmp,
const float* __restrict__ As,
const float* __restrict__ Bs,
int64_t M,
int64_t N,
int64_t K,
int64_t lda,
int64_t ldb,
int64_t ldc,
bool brg) {
tinygemm_kernel<scalar_t, false>(A, B, C, Ctmp, As, Bs, nullptr, M, N, K, lda, ldb, ldc, brg);
}
#define INSTANTIATE_TINYGEMM_TEMPLATE(TYPE) \
template void tinygemm_kernel<TYPE>( \
const uint8_t* __restrict__ A, \
const int8_t* __restrict__ B, \
TYPE* __restrict__ C, \
int32_t* __restrict__ Ctmp, \
const float* __restrict__ As, \
const float* __restrict__ Bs, \
int64_t M, \
int64_t N, \
int64_t K, \
int64_t lda, \
int64_t ldb, \
int64_t ldc, \
bool brg)
INSTANTIATE_TINYGEMM_TEMPLATE(at::BFloat16);
INSTANTIATE_TINYGEMM_TEMPLATE(at::Half);
std::tuple<at::Tensor, at::Tensor> per_token_quant_int8_cpu(at::Tensor& A) {
RECORD_FUNCTION("sgl-kernel::per_token_quant_int8_cpu", std::vector<c10::IValue>({A}));
CHECK_LAST_DIM_CONTIGUOUS_INPUT(A);
CHECK_DIM(2, A);
int64_t M = A.size(0);
int64_t K = A.size(1);
int64_t lda = A.stride(0);
const auto st = A.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "per_token_quant_int8: expect A to be bfloat16 or half.");
auto Aq = at::empty({M, K}, A.options().dtype(at::kByte));
auto As = at::empty({M}, A.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(st, "per_token_quant_int8", [&] {
uint8_t* __restrict__ Aq_data = Aq.data_ptr<uint8_t>();
float* __restrict__ As_data = As.data_ptr<float>();
const scalar_t* __restrict__ A_data = A.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
});
return std::make_tuple(Aq, As);
}
// weight : static, per-channel, symmetric
// activation : dynamic, per-token, symmetric
//
// mat1 : [M, K]
// mat2 : [N, K]
// scales1 : [M]
// scales2 : [N]
// bias : [N]
// out : [M, N]
//
at::Tensor int8_scaled_mm_cpu(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales1,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales1, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales1);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales1.numel(), M);
CHECK_EQ(scales2.numel(), N);
TORCH_CHECK(mat1.scalar_type() == at::kByte, "int8_scaled_mm: expect mat1 to be uint8.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm: expect mat2 to be int8.");
TORCH_CHECK(
scales1.scalar_type() == at::kFloat && scales2.scalar_type() == at::kFloat,
"int8_scaled_mm: expect scales to be float32.");
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_kernel_impl", [&] {
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
mat1.data_ptr<uint8_t>(),
packed_w.data_ptr<int8_t>(),
scales1.data_ptr<float>(),
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
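// Illustrative usage (a minimal sketch, not part of this diff): dynamic per-token
// activation quantization followed by the int8 GEMM.
//
//   at::Tensor w = at::randint(-128, 127, {256, 512}, at::dtype(at::kChar));  // [N, K]
//   at::Tensor packed = convert_weight_packed(w);         // int8 VNNI layout, [N, K + 4]
//   at::Tensor ws = at::ones({256}, at::dtype(at::kFloat));                   // per-channel scales
//   at::Tensor x = at::randn({32, 512}, at::dtype(at::kBFloat16));            // [M, K]
//   auto [xq, xs] = per_token_quant_int8_cpu(x);           // uint8 activations + per-token scales
//   std::optional<at::Tensor> bias;
//   at::Tensor y = int8_scaled_mm_cpu(xq, packed, xs, ws, bias, at::kBFloat16, /*is_vnni=*/true);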
// fused `per_token_quant_int8_cpu` and `int8_scaled_mm_cpu`
at::Tensor int8_scaled_mm_with_quant(
at::Tensor& mat1,
at::Tensor& mat2,
at::Tensor& scales2,
std::optional<at::Tensor>& bias,
at::ScalarType out_dtype,
bool is_vnni) {
RECORD_FUNCTION("sgl-kernel::int8_scaled_mm_cpu", std::vector<c10::IValue>({mat1, mat2, scales2, bias}));
auto packed_w = is_vnni ? mat2 : convert_weight_packed(mat2);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(mat1);
CHECK_INPUT(mat2);
CHECK_INPUT(scales2);
CHECK_DIM(2, mat1);
CHECK_DIM(2, mat2);
int64_t M = mat1.size(0);
int64_t N = mat2.size(0);
int64_t K = mat1.size(1);
int64_t lda = mat1.stride(0);
// see [NOTE]: s8s8 igemm compensation in avx512-vnni
CHECK_EQ(mat2.size(1), (int64_t)(is_vnni ? K + sizeof(int32_t) : K));
CHECK_EQ(scales2.numel(), N);
const auto st = mat1.scalar_type();
TORCH_CHECK(st == at::kBFloat16 || st == at::kHalf, "int8_scaled_mm_with_quant: expect A to be bfloat16 or half.");
TORCH_CHECK(st == out_dtype, "int8_scaled_mm_with_quant: expect A to have the same dtype as out_dtype.");
TORCH_CHECK(mat2.scalar_type() == at::kChar, "int8_scaled_mm_with_quant: expect mat2 to be int8.");
TORCH_CHECK(scales2.scalar_type() == at::kFloat, "int8_scaled_mm_with_quant: expect scales to be float32.");
const int64_t buffer_size = M * K + M * sizeof(float);
auto buffer = at::empty({buffer_size}, mat1.options().dtype(at::kByte));
auto out = at::empty({M, N}, mat1.options().dtype(out_dtype));
const bool has_bias = bias.has_value();
const float* bias_data = nullptr;
if (has_bias) {
CHECK_EQ(bias.value().size(0), N);
bias_data = bias.value().data_ptr<float>();
}
AT_DISPATCH_REDUCED_FLOATING_TYPES(out_dtype, "int8_scaled_mm_with_quant_kernel_impl", [&] {
uint8_t* __restrict__ Aq_data = buffer.data_ptr<uint8_t>();
float* __restrict__ As_data = (float*)((void*)(Aq_data + M * K));
const scalar_t* __restrict__ A_data = mat1.data_ptr<scalar_t>();
at::parallel_for(0, M, 0, [&](int64_t begin, int64_t end) {
for (int64_t m = begin; m < end; ++m) {
quantize_row_int8<scalar_t>(Aq_data + m * K, As_data[m], A_data + m * lda, K);
}
});
int8_scaled_mm_kernel_impl<scalar_t>(
out.data_ptr<scalar_t>(),
Aq_data,
packed_w.data_ptr<int8_t>(),
As_data,
scales2.data_ptr<float>(),
bias_data,
M,
N,
K);
});
return out;
}
#include <ATen/record_function.h>
#include <torch/extension.h>
#include "shm.h"
// Communication settings
static int world_rank = -1;
static int world_size = -1;
static bool is_initialized = false;
static bool all_ranks_local_p = false;
void initialize(int size, int rank) {
if (is_initialized) {
return;
}
// Check whether all ranks are on the same physical machine.
// If true, we will use an SHM based low latency allreduce
auto ls_string = std::getenv("LOCAL_SIZE");
int ls = 0;
if (ls_string != NULL) {
ls = std::stoi(ls_string);
}
if (size >= 1 && size == ls) {
all_ranks_local_p = true;
}
world_size = size;
world_rank = rank;
is_initialized = true;
auto addr_string = std::getenv("MASTER_ADDR");
if (addr_string == NULL) {
addr_string = "";
}
auto port_string = std::getenv("MASTER_PORT");
if (port_string == NULL) {
port_string = "";
}
if (all_ranks_local_p) {
shm_initialize(size, rank, addr_string, port_string);
}
}
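// Illustrative environment for the SHM fast path (sketch only): the low-latency allreduce
// is enabled only when every rank runs on the same host, which is detected via LOCAL_SIZE.
//
//   LOCAL_SIZE=8 MASTER_ADDR=127.0.0.1 MASTER_PORT=29500 <launch command>   # world_size == 8
//
// If LOCAL_SIZE does not match the world size, or the dtype is unsupported, the collectives
// below fall back to the regular torch.distributed implementations.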
void shm_allreduce(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, py::object op) {
RECORD_FUNCTION("sgl-kernel::shm_allreduce", std::vector<c10::IValue>({data}));
static py::object ReduceOp = py::module_::import("torch.distributed").attr("ReduceOp");
static auto ReduceOpSum = (int)py::int_(ReduceOp.attr("SUM").attr("value"));
TORCH_CHECK(py::int_(op.attr("value")) == ReduceOpSum, "Only torch.distributed.ReduceOp.SUM is supported");
auto numel = data.numel();
int data_size = 0;
bool data_type_fallback = false;
switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}
if (data_type_fallback || !all_ranks_local_p) {
// Fallback to torch distributed allreduce
std::vector<torch::Tensor> tensors = {data};
process_group->allreduce(tensors)->wait();
} else {
all_reduce_outer_loop(data, numel, data_size);
}
return;
}
torch::Tensor shm_allgather(torch::Tensor& data, c10::intrusive_ptr<c10d::ProcessGroup> process_group, int dim) {
RECORD_FUNCTION("sgl-kernel::shm_allgather", std::vector<c10::IValue>({data}));
auto numel = data.numel();
int data_size = 0;
bool data_type_fallback = false;
switch (data.scalar_type()) {
case c10::ScalarType::BFloat16:
data_size = numel * 2;
break;
case c10::ScalarType::Float:
data_size = numel * 4;
break;
default:
data_type_fallback = true;
}
if (dim < 0) {
dim += data.dim();
}
if (data_type_fallback || !all_ranks_local_p) {
// Fallback to torch distributed allgather
std::vector<std::vector<torch::Tensor>> output_tensors(1);
auto world_size = process_group->getSize();
for (int i = 0; i < world_size; i++) {
output_tensors[0].push_back(torch::empty_like(data));
}
std::vector<torch::Tensor> input_tensors = {data};
process_group->allgather(output_tensors, input_tensors)->wait();
return torch::cat(output_tensors[0], dim).contiguous();
}
std::vector<int64_t> result_shape = data.sizes().vec();
result_shape[dim] *= world_size;
torch::Tensor result_tensor = torch::empty(result_shape, data.options());
return all_gather(result_tensor, data, dim, numel, data_size);
}
#include "common.h"
#include "vec.h"
namespace {
// NB: avoid using `at::vec::map<>` on bfloat16 or half
template <typename scalar_t>
void rmsnorm_kernel_impl(
scalar_t* __restrict__ output,
const scalar_t* __restrict__ input,
const scalar_t* __restrict__ weight,
int64_t batch_size,
int64_t hidden_size,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ out_ptr = output + i * hidden_size;
const scalar_t* __restrict__ input_ptr = input + i * hidden_size;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
sum_val += x_val * x_val;
}
sum_val += vec_reduce_sum(sum_fvec);
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
const fVec scale_fvec = fVec(rsqrt_var);
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec w_bvec = bVec::loadu(weight + d);
fVec w_fvec0, w_fvec1;
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(out_ptr + d);
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float w_val = static_cast<float>(weight[d]);
out_ptr[d] = static_cast<scalar_t>(x_val * rsqrt_var * w_val);
}
}
});
}
template <typename scalar_t>
void fused_add_rmsnorm_kernel_impl(
scalar_t* __restrict__ input,
scalar_t* __restrict__ residual,
const scalar_t* __restrict__ weight,
float* __restrict__ buffer,
int64_t batch_size,
int64_t hidden_size,
float eps = 1e-5) {
using bVec = at::vec::Vectorized<scalar_t>;
using fVec = at::vec::Vectorized<float>;
constexpr int kVecSize = bVec::size();
at::parallel_for(0, batch_size, 0, [&](int64_t begin, int64_t end) {
int tid = at::get_thread_num();
float* __restrict__ buffer_ptr = buffer + tid * hidden_size;
for (int64_t i = begin; i < end; ++i) {
// local ptrs
scalar_t* __restrict__ input_ptr = input + i * hidden_size;
scalar_t* __restrict__ residual_ptr = residual + i * hidden_size;
fVec sum_fvec = fVec(float(0));
float sum_val = float(0);
int64_t d;
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
bVec x_bvec = bVec::loadu(input_ptr + d);
fVec x_fvec0, x_fvec1;
std::tie(x_fvec0, x_fvec1) = at::vec::convert_to_float(x_bvec);
bVec r_bvec = bVec::loadu(residual_ptr + d);
fVec r_fvec0, r_fvec1;
std::tie(r_fvec0, r_fvec1) = at::vec::convert_to_float(r_bvec);
x_fvec0 += r_fvec0;
x_fvec1 += r_fvec1;
bVec out_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
out_bvec.store(residual_ptr + d);
sum_fvec += x_fvec0 * x_fvec0;
sum_fvec += x_fvec1 * x_fvec1;
x_fvec0.store(buffer_ptr + d);
x_fvec1.store(buffer_ptr + d + fVec::size());
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = static_cast<float>(input_ptr[d]);
float r_val = static_cast<float>(residual_ptr[d]);
x_val += r_val;
residual_ptr[d] = static_cast<scalar_t>(x_val);
sum_val += x_val * x_val;
buffer_ptr[d] = x_val;
}
sum_val += vec_reduce_sum(sum_fvec);
float rsqrt_var = float(1) / std::sqrt(sum_val / hidden_size + eps);
const fVec scale_fvec = fVec(rsqrt_var);
#pragma GCC unroll 4
for (d = 0; d <= hidden_size - kVecSize; d += kVecSize) {
fVec x_fvec0 = fVec::loadu(buffer_ptr + d);
fVec x_fvec1 = fVec::loadu(buffer_ptr + d + fVec::size());
bVec w_bvec = bVec::loadu(weight + d);
fVec w_fvec0, w_fvec1;
std::tie(w_fvec0, w_fvec1) = at::vec::convert_to_float(w_bvec);
x_fvec0 = x_fvec0 * scale_fvec * w_fvec0;
x_fvec1 = x_fvec1 * scale_fvec * w_fvec1;
bVec x_bvec = convert_from_float_ext<scalar_t>(x_fvec0, x_fvec1);
x_bvec.store(input_ptr + d);
}
#pragma GCC unroll 4
for (; d < hidden_size; ++d) {
float x_val = buffer_ptr[d] * rsqrt_var * static_cast<float>(weight[d]);
input_ptr[d] = static_cast<scalar_t>(x_val);
}
}
});
}
} // anonymous namespace
// input : {batch_size, hidden_size}
// weight: {hidden_size}
at::Tensor rmsnorm_cpu(at::Tensor& input, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::rmsnorm_cpu", std::vector<c10::IValue>({input, weight}));
CHECK_INPUT(input);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
CHECK_DIM(1, weight);
CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
at::Tensor output = at::empty_like(input);
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "rmsnorm_kernel", [&] {
rmsnorm_kernel_impl<scalar_t>(
output.data_ptr<scalar_t>(),
input.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
batch_size,
hidden_size,
eps);
});
return output;
}
// input : {batch_size, hidden_size}
// residual: {batch_size, hidden_size}
// weight : {hidden_size}
void fused_add_rmsnorm_cpu(at::Tensor& input, at::Tensor& residual, at::Tensor& weight, double eps) {
RECORD_FUNCTION("sgl-kernel::fused_add_rmsnorm_cpu", std::vector<c10::IValue>({input, residual, weight}));
CHECK_INPUT(input);
CHECK_INPUT(residual);
CHECK_INPUT(weight);
CHECK_DIM(2, input);
CHECK_DIM(2, residual);
CHECK_DIM(1, weight);
CHECK_EQ(input.size(0), residual.size(0));
CHECK_EQ(input.size(1), residual.size(1));
CHECK_EQ(input.size(1), weight.size(0));
int64_t batch_size = input.size(0);
int64_t hidden_size = input.size(1);
// allocate temp buffer to store x in float32 per thread
// TODO: implement a singleton for context
int64_t num_threads = at::get_num_threads();
at::Tensor buffer = at::empty({num_threads, hidden_size}, input.options().dtype(at::kFloat));
AT_DISPATCH_REDUCED_FLOATING_TYPES(input.scalar_type(), "fused_add_rmsnorm_kernel", [&] {
fused_add_rmsnorm_kernel_impl<scalar_t>(
input.data_ptr<scalar_t>(),
residual.data_ptr<scalar_t>(),
weight.data_ptr<scalar_t>(),
buffer.data_ptr<float>(),
batch_size,
hidden_size,
eps);
});
}
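// Illustrative usage (a minimal sketch, not part of this diff): both ops normalize each row
// as x / sqrt(mean(x^2) + eps) * weight; the fused variant also adds the residual in place.
//
//   at::Tensor x = at::randn({4, 4096}, at::dtype(at::kBFloat16));
//   at::Tensor r = at::randn({4, 4096}, at::dtype(at::kBFloat16));
//   at::Tensor w = at::ones({4096}, at::dtype(at::kBFloat16));
//   at::Tensor y = rmsnorm_cpu(x, w, /*eps=*/1e-6);
//   fused_add_rmsnorm_cpu(x, r, w, /*eps=*/1e-6);   // x <- rmsnorm(x + r), r <- x + r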
#include "common.h"
#include "vec.h"
namespace {
template <typename scalar_t>
void rope_kernel_impl(
scalar_t* __restrict__ q_pe_out,
scalar_t* __restrict__ k_pe_out,
int64_t* __restrict__ t_pos,
scalar_t* __restrict__ q_pe,
scalar_t* __restrict__ k_pe,
scalar_t* __restrict__ t_emb_pos,
int64_t seq_len,
int64_t num_head,
int64_t rotary_dim,
int64_t HR,
int64_t q_pe_stride_s,
int64_t out_stride_qs,
int64_t out_stride_ks,
int64_t HK,
int64_t k_pe_stride_s,
int64_t q_pe_stride_n,
int64_t out_stride_qn) {
int64_t COFF = HR / 2;
at::parallel_for(0, seq_len * num_head, GRAIN_SIZE / rotary_dim, [&](int64_t begin, int64_t end) {
int64_t seq{0}, head_id{0};
data_index_init(begin, seq, seq_len, head_id, num_head);
for (int64_t i = begin; i < end; ++i) {
int64_t in_offset_q = seq * q_pe_stride_s + head_id * q_pe_stride_n;
int64_t out_offset_q = seq * out_stride_qs + head_id * out_stride_qn;
int64_t out_offset_k = seq * out_stride_ks;
int64_t p = 0;
scalar_t* sin_start = nullptr;
scalar_t* cos_start = nullptr;
// step 0) get the rotary position embedding for the current position
p = t_pos[seq];
sin_start = t_emb_pos + p * HR + COFF;
cos_start = t_emb_pos + p * HR;
// step 1) apply_rotary_pos_emb for the rotary_dim elements in every
// head of query/key
for (int64_t h = 0; h < rotary_dim; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
scalar_t in1 = q_pe[in_offset_q + h];
scalar_t in2 = q_pe[in_offset_q + h + 1];
scalar_t out1 = in1 * cos - in2 * sin;
scalar_t out2 = in2 * cos + in1 * sin;
q_pe_out[out_offset_q + h] = out1;
q_pe_out[out_offset_q + h + 1] = out2;
}
for (int64_t h = 0; h < HK; h += 2) {
scalar_t cos = cos_start[h >> 1];
scalar_t sin = sin_start[h >> 1];
int64_t k_pe_offset = seq * k_pe_stride_s;
scalar_t in1_k = k_pe[k_pe_offset + h];
scalar_t in2_k = k_pe[k_pe_offset + h + 1];
scalar_t out1_k = in1_k * cos - in2_k * sin;
scalar_t out2_k = in2_k * cos + in1_k * sin;
k_pe_out[out_offset_k + h] = out1_k;
k_pe_out[out_offset_k + h + 1] = out2_k;
}
// move to the next index
data_index_step(seq, seq_len, head_id, num_head);
}
});
}
} // namespace
std::tuple<at::Tensor, at::Tensor>
rotary_position_embedding_cpu(at::Tensor& t_pos, at::Tensor& q_pe, at::Tensor& k_pe, at::Tensor& t_emb_pos) {
RECORD_FUNCTION(
"sgl-kernel::rotary_position_embedding_cpu", std::vector<c10::IValue>({t_pos, q_pe, k_pe, t_emb_pos}));
CHECK_INPUT(t_pos);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_pe);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_pe);
CHECK_INPUT(t_emb_pos);
CHECK_DIM(1, t_pos);
CHECK_DIM(3, q_pe);
CHECK_DIM(3, k_pe);
CHECK_DIM(2, t_emb_pos);
int64_t seq_len = q_pe.size(0);
int64_t num_head = q_pe.size(1);
int64_t rotary_dim = q_pe.size(2);
int64_t HK = k_pe.size(2);
int64_t HR = t_emb_pos.size(1);
CHECK_EQ(HR, rotary_dim);
CHECK_EQ(k_pe.size(0), seq_len);
CHECK_EQ(k_pe.size(1), 1);
CHECK_EQ(t_pos.size(0), seq_len);
CHECK_EQ(HK, rotary_dim);
at::Tensor q_pe_out = at::empty_like(q_pe);
at::Tensor k_pe_out = at::empty_like(k_pe);
int64_t q_pe_stride_s = q_pe.stride(0);
int64_t q_pe_stride_n = q_pe.stride(1);
int64_t k_pe_stride_s = k_pe.stride(0);
int64_t out_stride_qs = q_pe_out.stride(0);
int64_t out_stride_qn = q_pe_out.stride(1);
int64_t out_stride_ks = k_pe_out.stride(0);
const auto input_dtype = q_pe.scalar_type();
TORCH_CHECK(t_pos.scalar_type() == at::kLong, "expect positions to be int64, got ", t_pos.scalar_type());
TORCH_CHECK(input_dtype == k_pe.scalar_type(), "q_pe and k_pe must have the same data type");
TORCH_CHECK(input_dtype == t_emb_pos.scalar_type(), "q_pe and t_emb_pos must have the same data type");
AT_DISPATCH_REDUCED_FLOATING_TYPES(input_dtype, "rotary_position_embedding_cpu", [&] {
rope_kernel_impl<scalar_t>(
q_pe_out.data_ptr<scalar_t>(),
k_pe_out.data_ptr<scalar_t>(),
t_pos.data_ptr<int64_t>(),
q_pe.data_ptr<scalar_t>(),
k_pe.data_ptr<scalar_t>(),
t_emb_pos.data_ptr<scalar_t>(),
seq_len,
num_head,
rotary_dim,
HR,
q_pe_stride_s,
out_stride_qs,
out_stride_ks,
HK,
k_pe_stride_s,
q_pe_stride_n,
out_stride_qn);
});
return std::make_tuple(q_pe_out, k_pe_out);
}
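// Layout notes (illustrative): t_emb_pos holds [cos(theta_0..theta_{HR/2-1}), sin(...)] per
// position, and each (even, odd) pair of rotary dims is rotated as
//   out_even = in_even * cos - in_odd * sin
//   out_odd  = in_odd  * cos + in_even * sin
// For example, with q_pe of shape {seq_len, num_head, 64} and t_emb_pos of shape
// {max_pos, 64}, position p uses cos = t_emb_pos[p][h / 2] and sin = t_emb_pos[p][32 + h / 2].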
#include <torch/torch.h>
#include <torch/csrc/distributed/c10d/ProcessGroup.hpp>
#ifndef __SHM_COLLECTIVES__
#define __SHM_COLLECTIVES__
#define VECTOR_LENGTH_IN_BYTES 32
void shm_initialize(int size, int rank, char* addr_string, char* port_string);
void all_reduce_outer_loop(torch::Tensor& data, size_t numel, int data_size);
torch::Tensor& all_gather(torch::Tensor& result, torch::Tensor& data, int dim, size_t numel, int data_size);
#endif