Unverified Commit 13fb8b54 authored by blzheng, committed by GitHub

[CPU] Optimize FP16 decode_attention_cpu (#10652)

parent 81fd2b0e
@@ -540,7 +540,10 @@ class ParallelLMHead(VocabParallelEmbedding):
        # We only support pack LMHead if it's not quantized.
        if _is_cpu and _is_cpu_amx_available:
-           if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
+           if hasattr(self, "weight") and self.weight.dtype in [
+               torch.bfloat16,
+               torch.float16,
+           ]:
                self.quant_method = PackWeightMethod(weight_names=["weight"])
        if bias:
......
@@ -308,6 +308,93 @@ struct tinygemm_kernel_nt<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
};
#endif
#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nt<at::Half, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const at::Half* __restrict__ A,
      const at::Half* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      float scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N;

    __m512 va0, va1;
    __m512 vb0[COLS], vb1[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale = _mm512_set1_ps(scale);

    auto loadc = [&](auto i) { vc[i] = _mm512_setzero_ps(); };
    Unroll<ROWS * COLS>{}(loadc);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      if constexpr (col == 0) {
        __m512i a16 = _mm512_loadu_si512((__m512i const*)(A + row * lda + k));
        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
      }
      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        __m512i b16 = _mm512_loadu_si512((__m512i const*)(B + b_idx * ldb + k));
        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
      }
      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
    };

    auto compute2 = [&](auto i, int64_t k, __mmask32 mask) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      if constexpr (col == 0) {
        __m512i a16 = _mm512_maskz_loadu_epi16(mask, (const void*)(A + row * lda + k));
        va0 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 0));
        va1 = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(a16, 1));
      }
      if constexpr (row == 0) {
        int64_t b_idx = indices[col];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        __m512i b16 = _mm512_maskz_loadu_epi16(mask, (const void*)(B + b_idx * ldb + k));
        vb0[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
        vb1[col] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
      }
      vc[i] = _mm512_fmadd_ps(va0, vb0[col], _mm512_fmadd_ps(va1, vb1[col], vc[i]));
    };

    int64_t k = 0;
    for (; k <= K - 32; k += 32) {
      Unroll<ROWS * COLS>{}(compute, k);
    }
    int64_t count = K - k;
    if (count > 0) {
      __mmask32 mask = (1ULL << count) - 1;
      Unroll<ROWS * COLS>{}(compute2, k, mask);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      C[row * ldc + col] = _mm512_reduce_add_ps(_mm512_mul_ps(vc[i], vscale));
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
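For reference, the masked tail in the kernel above covers the last K % 32 elements of the reduction dimension: a 16-bit masked load zero-fills the unused lanes, both 16-element halves are widened to FP32 with CVT_FP16_TO_FP32, and the FMA accumulators are reduced at the end. A minimal standalone sketch of that pattern (illustrative only; fp16_dot_tail is not a function in this PR), assuming AVX-512F/BW/DQ:

#include <immintrin.h>
#include <cstdint>

// Dot product of two FP16 buffers of length n (0 < n <= 32), widened to FP32.
// Mirrors the compute2 tail path of the kernel above.
static inline float fp16_dot_tail(const uint16_t* a, const uint16_t* b, int64_t n) {
  __mmask32 mask = (1ULL << n) - 1;                                // keep only the n valid lanes
  __m512i a16 = _mm512_maskz_loadu_epi16(mask, a);                 // masked 16-bit load, rest zeroed
  __m512i b16 = _mm512_maskz_loadu_epi16(mask, b);
  __m512 a0 = _mm512_cvtph_ps(_mm512_extracti32x8_epi32(a16, 0));  // lanes 0..15 as FP32
  __m512 a1 = _mm512_cvtph_ps(_mm512_extracti32x8_epi32(a16, 1));  // lanes 16..31 as FP32
  __m512 b0 = _mm512_cvtph_ps(_mm512_extracti32x8_epi32(b16, 0));
  __m512 b1 = _mm512_cvtph_ps(_mm512_extracti32x8_epi32(b16, 1));
  __m512 acc = _mm512_fmadd_ps(a0, b0, _mm512_mul_ps(a1, b1));     // a0*b0 + a1*b1
  return _mm512_reduce_add_ps(acc);                                // horizontal sum of 16 partials
}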
#define LAUNCH_TINYGEMM_KERNEL_NT(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nt<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, B, C + mb_start * ldc + nb_start, indices + nb_start, scale, lda, ldb, ldc, K, max_tokens);

@@ -443,6 +530,87 @@ struct tinygemm_kernel_nn<at::BFloat16, index_t, BLOCK_M, BLOCK_N> {
};
#endif
#if defined(CPU_CAPABILITY_AVX512)
template <typename index_t, int BLOCK_M, int BLOCK_N>
struct tinygemm_kernel_nn<at::Half, index_t, BLOCK_M, BLOCK_N> {
  static inline void apply(
      const float* __restrict__ A,
      const at::Half* __restrict__ B,
      float* __restrict__ C,
      const index_t* __restrict__ indices,
      const float* __restrict__ scale,
      int64_t lda,
      int64_t ldb,
      int64_t ldc,
      int64_t K,
      int64_t max_tokens) {
    constexpr int ROWS = BLOCK_M;
    constexpr int COLS = BLOCK_N / 16;

    __m512 va;
    __m512 vb[COLS];
    __m512 vc[ROWS * COLS];
    __m512 vscale;

    auto loadc = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Warray-bounds"
      if constexpr (col == 0) {
        vscale = _mm512_set1_ps(scale[row]);
      }
#pragma GCC diagnostic pop
      vc[i] = _mm512_loadu_ps(C + row * ldc + col * 16);
      vc[i] = _mm512_mul_ps(vc[i], vscale);
    };
    Unroll<ROWS * COLS>{}(loadc);

    auto compute = [&](auto i, int64_t k) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      if constexpr (col == 0) {
        va = _mm512_set1_ps(A[row * lda + k]);
      }
      if constexpr (row == 0) {
        if (k + 1 < K) {
          int64_t b_idx_prefetch = indices[k + 1];
          _mm_prefetch(B + b_idx_prefetch * ldb + col * 16, _MM_HINT_T0);
        }
        int64_t b_idx = indices[k];
        TORCH_CHECK(b_idx < max_tokens, "token index out of scope!");
        // for COLS = 2, 4, 6, 8 use 512 bit load
        // for COLS = 1, 3, 5, 7 use 256 bit load
        if constexpr (COLS % 2 == 0) {
          if constexpr (col % 2 == 0) {
            __m512i b16 = _mm512_loadu_si512(reinterpret_cast<const __m512i*>(B + b_idx * ldb + col * 16));
            vb[col + 0] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 0));
            vb[col + 1] = CVT_FP16_TO_FP32(_mm512_extracti32x8_epi32(b16, 1));
          }
        } else {
          __m256i b16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(B + b_idx * ldb + col * 16));
          vb[col] = CVT_FP16_TO_FP32(b16);
        }
      }
      vc[i] = _mm512_fmadd_ps(va, vb[col], vc[i]);
    };
    for (int64_t k = 0; k < K; ++k) {
      Unroll<ROWS * COLS>{}(compute, k);
    }

    auto storec = [&](auto i) {
      constexpr int row = i / COLS;
      constexpr int col = i % COLS;
      _mm512_storeu_ps(C + row * ldc + col * 16, vc[i]);
    };
    Unroll<ROWS * COLS>{}(storec);
  }
};
#endif
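The nn kernel above first rescales the running FP32 accumulators by the per-row scale, then accumulates FP16 value rows gathered through indices, prefetching the next row one iteration ahead. A simplified standalone sketch of that update for a single 16-wide output block (weighted_accum_block16 is an illustrative name, not part of this PR), assuming AVX-512:

#include <immintrin.h>
#include <cstdint>

// out: 16 FP32 accumulators; weights: per-token coefficients of length K;
// V: FP16 value cache with row stride ldv; indices: token ids selecting rows of V.
static inline void weighted_accum_block16(
    float* out, const float* weights, const uint16_t* V,
    const int64_t* indices, float scale, int64_t K, int64_t ldv) {
  __m512 acc = _mm512_mul_ps(_mm512_loadu_ps(out), _mm512_set1_ps(scale));  // rescale, as loadc does
  for (int64_t k = 0; k < K; ++k) {
    if (k + 1 < K) {  // prefetch the next gathered row, as the compute lambda does
      _mm_prefetch(reinterpret_cast<const char*>(V + indices[k + 1] * ldv), _MM_HINT_T0);
    }
    __m256i v16 = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(V + indices[k] * ldv));
    acc = _mm512_fmadd_ps(_mm512_set1_ps(weights[k]), _mm512_cvtph_ps(v16), acc);  // acc += w[k] * V[idx]
  }
  _mm512_storeu_ps(out, acc);
}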
#define LAUNCH_TINYGEMM_KERNEL_NN(MB_SIZE, NB_SIZE) \
  tinygemm_kernel_nn<scalar_t, index_t, MB_SIZE, NB_SIZE>::apply( \
      A + mb_start * lda, \
@@ -512,9 +680,10 @@ void index_gemm_kernel_nt(
    return;
  }
-  // pattern: 1-6-24
+  // default pattern: 1-6-24
+  // FP16 pattern: 2-8-16
  constexpr int64_t BLOCK_M = 4;
-  constexpr int64_t BLOCK_N = 6;
+  constexpr int64_t BLOCK_N = std::is_same_v<scalar_t, at::Half> ? 4 : 6;
  const int64_t MB = div_up(M, BLOCK_M);
  const int64_t NB = div_up(N, BLOCK_N);
......
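The register patterns named in the comment count vector registers per tile: the FP16 nt kernel holds 2 A registers (va0/va1), 2 x BLOCK_N = 8 B registers (vb0/vb1) and 4 x 4 = 16 accumulators (vc), hence 2-8-16, versus 1-6-24 for the BF16 kernel with BLOCK_N = 6. The driver then walks the M x N output in BLOCK_M x BLOCK_N steps using div_up, roughly as in this hedged sketch (for_each_tile is illustrative, not the actual helper in the file):

#include <algorithm>
#include <cstdint>

inline int64_t div_up(int64_t x, int64_t y) { return (x + y - 1) / y; }

// Visit every BLOCK_M x BLOCK_N tile of an M x N output, clamping edge tiles.
template <typename TileFn>
void for_each_tile(int64_t M, int64_t N, int64_t BLOCK_M, int64_t BLOCK_N, TileFn fn) {
  for (int64_t mb = 0; mb < div_up(M, BLOCK_M); ++mb) {
    for (int64_t nb = 0; nb < div_up(N, BLOCK_N); ++nb) {
      const int64_t mb_start = mb * BLOCK_M;
      const int64_t nb_start = nb * BLOCK_N;
      fn(mb_start, std::min(BLOCK_M, M - mb_start),
         nb_start, std::min(BLOCK_N, N - nb_start));
    }
  }
}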
@@ -47,7 +47,7 @@ convert_from_float_ext<at::BFloat16>(const Vectorized<float>& a, const Vectorize
#define CVT_BF16_TO_FP32(a) _mm512_castsi512_ps(_mm512_slli_epi32(_mm512_cvtepu16_epi32(a), 16))
-#define CVT_FP16_TO_FP32(a) _mm512_cvtps_ph(a, (_MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC))
+#define CVT_FP16_TO_FP32(a) _mm512_cvtph_ps(a)

// this doesn't handle NaN.
inline __m512bh cvt_e4m3_bf16_intrinsic_no_nan(__m256i fp8_vec) {
......
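The vec.h change above fixes the direction of the conversion: _mm512_cvtps_ph packs FP32 down to FP16, whereas this macro is used to widen 16 FP16 values held in a __m256i up to 16 FP32 lanes, which is what _mm512_cvtph_ps does. A tiny hedged usage sketch (the helper name is illustrative):

#include <immintrin.h>
#include <cstdint>

// Widen 16 IEEE FP16 values to FP32 -- the operation CVT_FP16_TO_FP32 now performs.
static inline __m512 fp16x16_to_fp32(const uint16_t* src) {
  __m256i h = _mm256_loadu_si256(reinterpret_cast<const __m256i*>(src));  // 16 FP16 bit patterns
  return _mm512_cvtph_ps(h);                                              // FP16 -> FP32 (AVX-512F)
}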
@@ -59,8 +59,7 @@ class TestDecodeAttention(CustomTestCase):
        return output

-   def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
-       dtype = torch.bfloat16
+   def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, dtype, device):
        # This represents the number of tokens already in the sequence
        seq_len = 1024
        total_tokens = B * seq_len
@@ -158,8 +157,9 @@ class TestDecodeAttention(CustomTestCase):
        ]

        for B, H_Q, H_KV, D, D_V in configs:
+           for dtype in [torch.bfloat16, torch.float16]:
                self._test_grouped_decode_attention_once(
-                   B, H_Q, H_KV, D, D_V, device=device
+                   B, H_Q, H_KV, D, D_V, dtype=dtype, device=device
                )

    def test_grouped_decode_attention(self):
......