Unverified Commit abe9f7bd authored by q.yao, Committed by GitHub
Browse files

[Fix] Support actual seqlen in flash-attention2 (#418)

* support actual seqlen

* fix lint

* update variable types

* lint

* update type

* fix lint

---------
parent 3a7880a8
......@@ -1422,8 +1422,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh
+ co * QK_ELTS_IN_16B + ci;
size_t offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
if (!QUANT_POLICY) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
......
......@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
layer_offset,
attention_mask,
cu_seqlens,
input_tensors->at("context_lengths").getPtr<int>(),
batch_size,
max_q_len,
max_k_len,
......@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
......@@ -274,13 +276,13 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
key_cache_ptrs};
Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
int(size_per_head_),
int(max_seq_len * size_per_head_),
false,
int(cache_layer_offset),
cache_layer_offset,
val_cache_ptrs};
Layout layout_o{
int(local_head_num_ * max_q_len * size_per_head_),
......@@ -298,6 +300,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
qk_buf_float_,
cu_seqlens,
nullptr,
nullptr,
context_lengths,
group_size,
layout_q,
layout_k,
......
......@@ -72,6 +72,7 @@ public:
size_t cache_layer_offset,
T* attention_mask,
int* cu_seqlens,
int* context_lengths,
int batch_size,
int max_q_len,
int max_k_len,
......
......@@ -130,7 +130,7 @@ static inline void fusedQKV_masked_attention_dispatch(const T* qkv_buf,
params.hidden_size_per_head = size_per_head;
params.rotary_embedding_dim = rotary_embedding_dim;
params.rotary_embedding_base = rotary_embedding_base;
params.rotary_embedding_base = rotary_embedding_base;
params.max_position_embeddings = max_position_embeddings;
params.use_dynamic_ntk = use_dynamic_ntk;
params.use_logn_attn = use_logn_attn;
......
......@@ -93,7 +93,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
vocab_size_padded_ =
(vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
......@@ -171,7 +172,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_padded_,
0, // end_id, deprecated
0, // end_id, deprecated
stream_,
cublas_wrapper_,
allocator_,
......
......@@ -95,8 +95,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
loadWeightFromBin(
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
loadWeightFromBin((T*)post_decoder_embedding_kernel,
{hidden_units_ * vocab_size_padded_},
dir_path + "output.weight",
model_file_type);
for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
......
......@@ -15,10 +15,14 @@ struct BlockInfo {
__device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q),
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k)
actual_seqlen_q(params.actual_seqlen_q == nullptr ?
(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
params.cu_seqlens_q[bidb + 1] - sum_s_q) :
params.actual_seqlen_q[bidb]),
actual_seqlen_k(params.actual_seqlen_k == nullptr ?
(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k) :
params.actual_seqlen_k[bidb])
{
}
......
......@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params {
using index_t = uint32_t;
using index_t = size_t;
// The QKV matrices.
void* __restrict__ q_ptr;
void* __restrict__ k_ptr;
......@@ -25,8 +25,8 @@ struct Qkv_params {
// batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0;
int v_batched_offset = 0;
size_t k_batched_offset = 0;
size_t v_batched_offset = 0;
// The stride between rows of the Q, K and V matrices.
index_t q_batch_stride;
......@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params {
int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k;
// array of length b with actual length of each sequence
int* __restrict__ actual_seqlen_q;
int* __restrict__ actual_seqlen_k;
void* __restrict__ blockmask;
bool is_bf16;
......
......@@ -121,6 +121,9 @@ public:
fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k;
fwd_params.actual_seqlen_q = params.actual_seqlen_q;
fwd_params.actual_seqlen_k = params.actual_seqlen_k;
fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
fwd_params.is_bf16 = false;
......
......@@ -70,10 +70,10 @@ struct LlamaAttentionKernel:
scalar_t** v_batch_seqs_ptr = nullptr;
output_t** o_batch_seqs_ptr = nullptr;
int q_batch_seqs_offset = 0;
int k_batch_seqs_offset = 0;
int v_batch_seqs_offset = 0;
int o_batch_seqs_offset = 0;
size_t q_batch_seqs_offset = 0;
size_t k_batch_seqs_offset = 0;
size_t v_batch_seqs_offset = 0;
size_t o_batch_seqs_offset = 0;
int32_t group_size = 1;
......@@ -81,7 +81,7 @@ struct LlamaAttentionKernel:
template<typename ptr_t>
CUTLASS_DEVICE void
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB)
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, size_t batch_seq_offset, int batch_id, int strideB)
{
if (batch_seq_ptr != nullptr)
data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset;
......
......@@ -80,12 +80,12 @@ void invokeMyCopyInt(int* dst, const int* src, size_t count, cudaStream_t st);
template<typename T>
struct BaseAttentionLayout {
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
int batch_seqs_offset = 0;
T** batch_seqs = nullptr;
int stride_batch;
int stride_seq;
int stride_head;
bool use_seqlens = false;
size_t batch_seqs_offset = 0;
T** batch_seqs = nullptr;
};
template<typename T>
......@@ -95,10 +95,12 @@ struct BaseAttentionParams {
T* key;
T* val;
T* mask;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
size_t group_size = 1;
float* out_accum = nullptr;
int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr;
int* actual_seqlen_q = nullptr;
int* actual_seqlen_k = nullptr;
size_t group_size = 1;
BaseAttentionLayout<T> layout_q;
BaseAttentionLayout<T> layout_k;
BaseAttentionLayout<T> layout_v;
......
......@@ -278,6 +278,8 @@ int main(int argc, const char* argv[])
// auto* input_lengths = (int*)allocator.malloc(sizeof(int) * batch_size, false);
thrust::device_vector<int> input_lengths(batch_size);
thrust::host_vector<int> input_lengths_host(batch_size);
thrust::device_vector<int> kv_lengths(batch_size);
thrust::host_vector<int> kv_lengths_host(batch_size);
cudaRandomUniform<scalar_t>(query_ptr, batch_size * num_heads * seq_len * size_per_head);
cudaRandomUniform<scalar_t>(key_ptr, batch_size * num_heads * key_len * size_per_head);
......@@ -285,13 +287,12 @@ int main(int argc, const char* argv[])
cudaRandomUniform<scalar_t>(mask_ptr, batch_size * seq_len * key_len);
// create random length for batch
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
// for(int batch_id=0;batch_id<batch_size;++batch_id){
// input_lengths_host[batch_id] = seq_len;
// }
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
{
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
}
size_t h_token_num = 0;
size_t* h_pinned_token_num;
auto input_lengths_ptr = thrust::raw_pointer_cast(input_lengths.data());
......@@ -306,10 +307,16 @@ int main(int argc, const char* argv[])
stream);
cudaFreeHost((void*)h_pinned_token_num);
int* k_lens = (int*)allocator.malloc(batch_size * sizeof(int));
deviceFill(k_lens, batch_size, key_len, stream);
{
std::uniform_int_distribution<int> dist{seq_len, key_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(kv_lengths_host), end(kv_lengths_host), gen);
thrust::copy(kv_lengths_host.begin(), kv_lengths_host.end(), kv_lengths.begin());
}
auto kv_lengths_ptr = thrust::raw_pointer_cast(kv_lengths.data());
// deviceFill(kv_lengths_ptr, batch_size, key_len, stream);
invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, k_lens, seq_len, key_len, batch_size, stream);
invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, kv_lengths_ptr, seq_len, key_len, batch_size, stream);
// deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream);
// compute gt
......@@ -356,6 +363,8 @@ int main(int argc, const char* argv[])
accum_buf_ptr,
cu_seqlens_ptr,
nullptr,
nullptr,
kv_lengths_ptr,
1,
layout_q,
layout_k,
......@@ -367,10 +376,10 @@ int main(int argc, const char* argv[])
int num_rows = 8;
// printf("query:\n");
// printMatrix(query_ptr, num_rows, 8, size_per_head, true);
printf("expect:\n");
printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
printf("actual:\n");
printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
// printf("expect:\n");
// printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
// printf("actual:\n");
// printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
checkResult(
"all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment