Commit a157cc8c authored by Tri Dao's avatar Tri Dao
Browse files

[FT] Implement MQA/GQA

parent 75e334d4
......@@ -69,7 +69,9 @@ struct Multihead_attention_params_base {
const int* cache_indir = nullptr;
// Stride to handle the case when KQV is a single buffer
int stride = 0;
int stride_q = 0;
int stride_k = 0;
int stride_v = 0;
// The batch size.
int batch_size = 0;
......@@ -79,6 +81,8 @@ struct Multihead_attention_params_base {
int memory_max_len = 0;
// The number of heads (H).
int num_heads = 0;
int num_heads_kv = 0;
int num_heads_q_kv_ratio = 0;
// The hidden dimension per head (Dh).
int hidden_size_per_head = 0;
// The per-head latent space reserved for rotary embeddings.
......
......@@ -943,10 +943,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
// The head.
// const int hi = blockIdx.x;
const int hi = params.nnz_head_idx == nullptr ? blockIdx.x : params.nnz_head_idx[blockIdx.x];
const int hi_kv = hi / params.num_heads_q_kv_ratio;
// Combine the batch and the head indices.
const int bhi = bi * params.num_heads + hi;
const int bhi_kv = bi * params.num_heads_kv + hi_kv;
// Combine the "beam-aware" batch idx and the head indices.
const int bbhi = bbi * params.beam_width * params.num_heads + hi;
const int bbhi = bbi * params.beam_width * params.num_heads_kv + hi_kv;
// The thread in the block.
const int tidx = threadIdx.x;
......@@ -957,7 +959,9 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
float qk = 0.0F;
int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
int q_base_offset = (params.stride_q == 0) ? bhi * Dh : bi * params.stride_q + hi * Dh;
int k_base_offset = (params.stride_k == 0) ? bhi_kv * Dh : bi * params.stride_k + hi_kv * Dh;
int v_base_offset = (params.stride_v == 0) ? bhi_kv * Dh : bi * params.stride_v + hi_kv * Dh;
const size_t bi_seq_len_offset = bi * params.memory_max_len;
......@@ -973,9 +977,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
const bool is_masked = tidx >= QK_VECS_PER_WARP;
// The offset in the Q and K buffer also accounts for the batch.
int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
// The offset in the bias buffer.
int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr;
const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0;
......@@ -989,12 +995,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
const auto q_scaling = params.qkv_scale_out[0];
const auto q_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[qk_offset]);
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);
convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
}
else {
q = *reinterpret_cast<const Qk_vec*>(&params.q[qk_offset]);
q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
}
}
......@@ -1007,7 +1013,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
// params.timestep*QK_ELTS_IN_16B +
tlength * QK_ELTS_IN_16B + ci;
k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
......@@ -1021,12 +1027,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
const auto k_scaling = params.qkv_scale_out[1];
const auto k_quant =
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[qk_offset]);
*reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);
convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
}
else {
k = *reinterpret_cast<const Qk_vec*>(&params.k[qk_offset]);
k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
}
}
}
......@@ -1035,14 +1041,14 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
Qk_vec q_bias;
zero(q_bias);
q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
*reinterpret_cast<const Qk_vec*>(&params.q_bias[qk_bias_offset]) :
*reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
q_bias;
Qk_vec k_bias;
zero(k_bias);
if (handle_kv) {
k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
*reinterpret_cast<const Qk_vec*>(&params.k_bias[qk_bias_offset]) :
*reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
k_bias;
}
......@@ -1172,11 +1178,11 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
// params.timestep*QK_ELTS_IN_16B +
tlength_circ * QK_ELTS_IN_16B + ci;
if (handle_kv) {
if (handle_kv && hi % params.num_heads_q_kv_ratio == 0) {
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
*reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
......@@ -1263,7 +1269,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
// The base pointer for the key in the cache buffer.
T* k_cache = &params.k_cache[bhi * params.memory_max_len * Dh + ki];
T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
// Base pointer for the beam's batch, before offsetting with indirection buffer
T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
......@@ -1427,7 +1433,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;
// The base pointer for the value in the cache buffer.
T* v_cache = &params.v_cache[bhi * params.memory_max_len * Dh + vi];
T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
// Base pointer for the beam's batch, before offsetting with indirection buffer
T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
......@@ -1443,7 +1449,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
if (vo == tlength % V_PER_ITER) {
// Trigger the loads from the V bias buffer.
if (params.v_bias != nullptr) {
v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi * Dh + vi]);
v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
}
if (DO_CROSS_ATTENTION) {
*reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
......@@ -1510,7 +1516,7 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
}
else {
// Trigger the loads from the V buffer.
const auto v_offset = qkv_base_offset + vi;
const auto v_offset = v_base_offset + vi;
if (params.int8_mode == 2) {
using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
......@@ -1539,8 +1545,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T,
}
// Store the values with bias back to global memory in the cache for V.
//*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
if (hi % params.num_heads_q_kv_ratio == 0) {
//*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
*reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
}
}
// Initialize the output value with the current timestep.
......
......@@ -50,13 +50,16 @@ template <typename T>
void set_params(Masked_multihead_attention_params<T> &params,
const size_t batch_size,
const size_t nheads,
const size_t nheads_kv,
const size_t memory_max_seqlen,
const size_t headdim,
const int timestep,
const int rotary_embedding_dim,
const float rotary_base,
const bool neox_rotary_style,
const int qkv_batch_stride,
const int q_batch_stride,
const int k_batch_stride,
const int v_batch_stride,
const int nnz_heads,
T *q_ptr,
T *k_ptr,
......@@ -80,11 +83,15 @@ void set_params(Masked_multihead_attention_params<T> &params,
params.v_cache = v_cache_ptr;
params.out = out_ptr;
params.cache_indir = nullptr;
params.stride = qkv_batch_stride;
params.stride_q = q_batch_stride;
params.stride_k = k_batch_stride;
params.stride_v = v_batch_stride;
params.batch_size = batch_size;
params.beam_width = 1;
params.memory_max_len = memory_max_seqlen;
params.num_heads = nheads;
params.num_heads_kv = nheads_kv;
params.num_heads_q_kv_ratio = nheads / nheads_kv;
params.nnz_heads = nnz_heads;
params.hidden_size_per_head = headdim;
params.rotary_embedding_dim = rotary_embedding_dim;
......@@ -124,23 +131,23 @@ torch::Tensor single_query_attention(const torch::Tensor q,
const bool neox_rotary_style=true) {
CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
int batch_size = v_cache.size(0);
int nheads = v_cache.size(1);
int nheads = q.size(1);
int nheads_kv = v_cache.size(1);
int memory_max_seqlen = v_cache.size(2);
int headdim = v_cache.size(3);
auto input_type = q.scalar_type();
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
CHECK_SHAPE(q, batch_size, nheads, headdim);
CHECK_SHAPE(k, batch_size, nheads, headdim);
CHECK_SHAPE(v, batch_size, nheads, headdim);
CHECK_SHAPE(v_cache, batch_size, nheads, memory_max_seqlen, headdim);
CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
// k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
CHECK_SHAPE(k_cache, batch_size, nheads, headdim / packsize, memory_max_seqlen, packsize);
CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);
TORCH_CHECK(q.scalar_type() == input_type);
......@@ -191,8 +198,9 @@ torch::Tensor single_query_attention(const torch::Tensor q,
DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
using DataType = typename SATypeConverter<scalar_t>::Type;
Masked_multihead_attention_params<DataType> params;
set_params(params, batch_size, nheads, memory_max_seqlen, headdim, timestep,
rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
rotary_embedding_dim, rotary_base, neox_rotary_style,
q.stride(0), k.stride(0), v.stride(0),
nnz_head_idx_.has_value() ? nnz_head_idx_.value().size(0) : 0,
reinterpret_cast<DataType*>(q.data_ptr()),
reinterpret_cast<DataType*>(k.data_ptr()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment