Commit a157cc8c authored by Tri Dao

[FT] Implement MQA/GQA

parent 75e334d4
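For context: with multi-query / grouped-query attention (MQA/GQA), several query heads share one key/value head, so the kernel and the KV cache only need `num_heads_kv` K/V heads instead of `num_heads`. The sketch below is not part of the commit (the head counts are made up); it only illustrates the query-head to KV-head mapping that the diff introduces as `hi_kv = hi / params.num_heads_q_kv_ratio`:

```cpp
#include <cstdio>

int main() {
    // Hypothetical GQA configuration: 8 query heads share 2 KV heads.
    // MQA is the special case num_heads_kv == 1.
    const int num_heads = 8;
    const int num_heads_kv = 2;
    const int num_heads_q_kv_ratio = num_heads / num_heads_kv;
    for (int hi = 0; hi < num_heads; ++hi) {
        // Same mapping as in the kernel: hi_kv = hi / params.num_heads_q_kv_ratio.
        const int hi_kv = hi / num_heads_q_kv_ratio;
        printf("query head %d -> kv head %d\n", hi, hi_kv);
    }
    return 0;
}
```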
@@ -69,7 +69,9 @@ struct Multihead_attention_params_base {
     const int* cache_indir = nullptr;
     // Stride to handle the case when KQV is a single buffer
-    int stride = 0;
+    int stride_q = 0;
+    int stride_k = 0;
+    int stride_v = 0;
     // The batch size.
     int batch_size = 0;
@@ -79,6 +81,8 @@ struct Multihead_attention_params_base {
     int memory_max_len = 0;
     // The number of heads (H).
     int num_heads = 0;
+    int num_heads_kv = 0;
+    int num_heads_q_kv_ratio = 0;
     // The hidden dimension per head (Dh).
     int hidden_size_per_head = 0;
     // The per-head latent space reserved for rotary embeddings.
...
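The sketch below is not from the commit: with made-up sizes, and under the assumption that Q, K and V are slices of one fused QKV buffer (the "KQV is a single buffer" case the comment mentions), it illustrates how the new stride and head-count fields might be populated. `ParamsSketch` and `make_params_sketch` are hypothetical names used only for this illustration:

```cpp
#include <cstdio>

// Hypothetical mirror of the new fields in Multihead_attention_params_base,
// used only to show how they relate to each other.
struct ParamsSketch {
    int stride_q = 0, stride_k = 0, stride_v = 0;
    int num_heads = 0, num_heads_kv = 0, num_heads_q_kv_ratio = 0;
    int hidden_size_per_head = 0;
};

ParamsSketch make_params_sketch(int num_heads, int num_heads_kv, int head_dim) {
    ParamsSketch p;
    p.num_heads = num_heads;
    p.num_heads_kv = num_heads_kv;
    p.num_heads_q_kv_ratio = num_heads / num_heads_kv;   // e.g. 32 / 8 = 4
    p.hidden_size_per_head = head_dim;
    // Assuming a fused [batch, (H + 2*H_kv) * Dh] QKV buffer: each of the
    // q/k/v pointers advances by the full fused row per batch element,
    // even though K and V only cover num_heads_kv heads.
    const int fused_row = (num_heads + 2 * num_heads_kv) * head_dim;
    p.stride_q = fused_row;
    p.stride_k = fused_row;
    p.stride_v = fused_row;
    return p;
}

int main() {
    const ParamsSketch p = make_params_sketch(/*num_heads=*/32, /*num_heads_kv=*/8, /*head_dim=*/128);
    printf("ratio=%d stride_q=%d stride_k=%d stride_v=%d\n",
           p.num_heads_q_kv_ratio, p.stride_q, p.stride_k, p.stride_v);
    return 0;
}
```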
@@ -943,10 +943,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     // The head.
     // const int hi = blockIdx.x;
     const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
+    const int hi_kv = hi / params.num_heads_q_kv_ratio;
     // Combine the batch and the head indices.
     const int bhi = bi * params.num_heads + hi;
+    const int bhi_kv = bi * params.num_heads_kv + hi_kv;
     // Combine the "beam-aware" batch idx and the head indices.
-    const int bbhi = bbi * params.beam_width * params.num_heads + hi;
+    const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
     // The thread in the block.
     const int tidx = threadIdx.x;
@@ -957,7 +959,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     float qk = 0.0F;
-    int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
+    int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
+    int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
+    int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
     const size_t bi_seq_len_offset = bi * params.memory_max_len;
@@ -973,9 +977,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     const bool is_masked = tidx >= QK_VECS_PER_WARP;
     // The offset in the Q and K buffer also accounts for the batch.
-    int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
+    int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
+    int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
     // The offset in the bias buffer.
-    int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
+    int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
     const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
     const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
@@ -989,12 +995,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto q_scaling = params.qkv_scale_out[0];
             const auto q_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
             convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
         }
         else {
-            q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
+            q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
         }
     }
@@ -1007,7 +1013,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength * QK_ELTS_IN_16B + ci;
         k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
@@ -1021,12 +1027,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
             using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
             const auto k_scaling = params.qkv_scale_out[1];
             const auto k_quant =
-                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
+                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
             convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
         }
         else {
-            k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
+            k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
         }
     }
 }
@@ -1035,14 +1041,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     Qk_vec q_bias;
     zero(q_bias);
     q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
-        *reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
+        *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
         q_bias;
     Qk_vec k_bias;
     zero(k_bias);
     if (handle_kv) {
         k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
-            *reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
+            *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
             k_bias;
     }
@@ -1172,11 +1178,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
         // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
-        int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
+        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                      // params.timestep*QK_ELTS_IN_16B +
                      tlength_circ * QK_ELTS_IN_16B + ci;
-        if (handle_kv) {
+        if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
             // Trigger the stores to global memory.
             if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                 *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
@@ -1263,7 +1269,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
     // The base pointer for the key in the cache buffer.
-    T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
@@ -1427,7 +1433,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
     // The base pointer for the value in the cache buffer.
-    T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
     // Base pointer for the beam's batch, before offsetting with indirection buffer
     T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
@@ -1443,7 +1449,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
     if (vo == tlength % V_PER_ITER) {
         // Trigger the loads from the V bias buffer.
         if (params.v_bias != nullptr) {
-            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
+            v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
         }
         if (DO_CROSS_ATTENTION) {
             *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
@@ -1510,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         }
         else {
             // Trigger the loads from the V buffer.
-            const auto v_offset = qkv_base_offset + vi;
+            const auto v_offset = v_base_offset + vi;
             if (params.int8_mode == 2) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
@@ -1539,9 +1545,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
         }
         // Store the values with bias back to global memory in the cache for V.
+        if (hi % params.num_heads_q_kv_ratio == 0) {
         //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
         *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
         }
+        }
     // Initialize the output value with the current timestep.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
...
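To summarize the kernel-side changes, the illustrative sketch below (made-up sizes, plain host C++ rather than the CUDA kernel) reproduces the two indexing ideas from the hunks above: the K/V cache is addressed by `bhi_kv = bi * num_heads_kv + hi_kv`, so its per-batch footprint shrinks to `num_heads_kv * memory_max_len * Dh`, and only the first query head of each group writes the shared cache slot (`hi % num_heads_q_kv_ratio == 0`):

```cpp
#include <cstdio>

int main() {
    // Hypothetical sizes; only the indexing pattern matters here.
    const int num_heads = 8, num_heads_kv = 2, Dh = 64;
    const int memory_max_len = 2048;
    const int bi = 1;                                  // batch index
    const int ratio = num_heads / num_heads_kv;
    for (int hi = 0; hi < num_heads; ++hi) {
        const int hi_kv = hi / ratio;
        const int bhi_kv = bi * num_heads_kv + hi_kv;  // mirrors bi * params.num_heads_kv + hi_kv
        // Base offset of this (batch, KV head) slice in the V cache,
        // matching bhi_kv * params.memory_max_len * Dh in the kernel.
        const long v_cache_base = (long)bhi_kv * memory_max_len * Dh;
        // Only the first query head of each group stores K/V into the cache,
        // mirroring the `hi % params.num_heads_q_kv_ratio == 0` guards above.
        const bool writes_cache = (hi % ratio == 0);
        printf("query head %d: kv slice base %ld (%s)\n",
               hi, v_cache_base, writes_cache ? "writes" : "read-only");
    }
    return 0;
}
```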
@@ -50,13 +50,16 @@ template <typename T>
 void set_params(Masked_multihead_attention_params<T> &params,
                 const size_t batch_size,
                 const size_t nheads,
+                const size_t nheads_kv,
                 const size_t memory_max_seqlen,
                 const size_t headdim,
                 const int timestep,
                 const int rotary_embedding_dim,
                 const float rotary_base,
                 const bool neox_rotary_style,
-                const int qkv_batch_stride,
+                const int q_batch_stride,
+                const int k_batch_stride,
+                const int v_batch_stride,
                 const int nnz_heads,
                 T *q_ptr,
                 T *k_ptr,
@@ -80,11 +83,15 @@ void set_params(Masked_multihead_attention_params<T> &params,
     params.v_cache = v_cache_ptr;
     params.out = out_ptr;
     params.cache_indir = nullptr;
-    params.stride = qkv_batch_stride;
+    params.stride_q = q_batch_stride;
+    params.stride_k = k_batch_stride;
+    params.stride_v = v_batch_stride;
     params.batch_size = batch_size;
     params.beam_width = 1;
     params.memory_max_len = memory_max_seqlen;
     params.num_heads = nheads;
+    params.num_heads_kv = nheads_kv;
+    params.num_heads_q_kv_ratio = nheads / nheads_kv;
     params.nnz_heads = nnz_heads;
     params.hidden_size_per_head = headdim;
     params.rotary_embedding_dim = rotary_embedding_dim;
@@ -124,23 +131,23 @@ torch::Tensor single_query_attention(const torch::Tensor q,
                                      const bool neox_rotary_style=true) {
     CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
     int batch_size = v_cache.size(0);
-    int nheads = v_cache.size(1);
+    int nheads = q.size(1);
+    int nheads_kv = v_cache.size(1);
     int memory_max_seqlen = v_cache.size(2);
     int headdim = v_cache.size(3);
     auto input_type = q.scalar_type();
     TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
     CHECK_SHAPE(q, batch_size, nheads, headdim);
-    CHECK_SHAPE(k, batch_size, nheads, headdim);
-    CHECK_SHAPE(v, batch_size, nheads, headdim);
-    CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
+    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
+    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
     // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
     int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
-    CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
+    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
     TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
     TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
     TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
-    TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
     CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
     TORCH_CHECK(q.scalar_type() == input_type);
@@ -191,8 +198,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
     DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
         using DataType = typename SATypeConverter<scalar_t>::Type;
         Masked_multihead_attention_params<DataType> params;
-        set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
-                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
+        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
+                   rotary_embedding_dim, rotary_base, neox_rotary_style,
+                   q.stride(0), k.stride(0), v.stride(0),
                    nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
                    reinterpret_cast<DataType*>(q.data_ptr()),
                    reinterpret_cast<DataType*>(k.data_ptr()),
...
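For reference, the shape contract that the updated binding checks can be summarized with the sketch below. The concrete numbers are hypothetical, and `packsize` follows the in-code comment (x = 8 for fp16, x = 4 for fp32):

```cpp
#include <cstdio>

int main() {
    // Hypothetical GQA setup with a 4:1 query-to-KV head ratio.
    const int batch = 2, nheads = 32, nheads_kv = 8;
    const int memory_max_seqlen = 4096, headdim = 128;
    const int packsize = 8;  // fp16/bf16 K-cache layout; 4 for fp32
    printf("q       : [%d, %d, %d]\n", batch, nheads, headdim);
    printf("k, v    : [%d, %d, %d]\n", batch, nheads_kv, headdim);
    printf("v_cache : [%d, %d, %d, %d]\n", batch, nheads_kv, memory_max_seqlen, headdim);
    printf("k_cache : [%d, %d, %d, %d, %d]\n", batch, nheads_kv, headdim / packsize,
           memory_max_seqlen, packsize);
    return 0;
}
```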