"docs/vscode:/vscode.git/clone" did not exist on "ca5d4db280d6db56205ae531a457273b9a40ecff"
Unverified commit abe9f7bd, authored by q.yao and committed by GitHub
Browse files

[Fix] Support actual seqlen in flash-attention2 (#418)

* support actual seqlen

* fix lint

* update variable types

* lint

* update type

* fix lint

---------
parent 3a7880a8
...@@ -1422,8 +1422,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> ...@@ -1422,8 +1422,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
// Trigger the stores to global memory. // Trigger the stores to global memory.
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
int offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh + tlength_circ * Dh size_t offset = params.kv_cache_per_sample_offset + kvhi * params.memory_max_len * Dh
+ co * QK_ELTS_IN_16B + ci; + tlength_circ * Dh + co * QK_ELTS_IN_16B + ci;
if (!QUANT_POLICY) { if (!QUANT_POLICY) {
*reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) =
......
...@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap* ...@@ -215,6 +215,7 @@ inline void LlamaContextAttentionLayer<T>::forward(TensorMap*
layer_offset, layer_offset,
attention_mask, attention_mask,
cu_seqlens, cu_seqlens,
input_tensors->at("context_lengths").getPtr<int>(),
batch_size, batch_size,
max_q_len, max_q_len,
max_k_len, max_k_len,
...@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr ...@@ -258,6 +259,7 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
size_t cache_layer_offset, size_t cache_layer_offset,
T* attention_mask, T* attention_mask,
int* cu_seqlens, int* cu_seqlens,
int* context_lengths,
int batch_size, int batch_size,
int max_q_len, int max_q_len,
int max_k_len, int max_k_len,
...@@ -274,13 +276,13 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr ...@@ -274,13 +276,13 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
int(size_per_head_), int(size_per_head_),
int(max_seq_len * size_per_head_), int(max_seq_len * size_per_head_),
false, false,
int(cache_layer_offset), cache_layer_offset,
key_cache_ptrs}; key_cache_ptrs};
Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_), Layout layout_v{int(local_head_num_ * max_seq_len * size_per_head_),
int(size_per_head_), int(size_per_head_),
int(max_seq_len * size_per_head_), int(max_seq_len * size_per_head_),
false, false,
int(cache_layer_offset), cache_layer_offset,
val_cache_ptrs}; val_cache_ptrs};
Layout layout_o{ Layout layout_o{
int(local_head_num_ * max_q_len * size_per_head_), int(local_head_num_ * max_q_len * size_per_head_),
...@@ -298,6 +300,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr ...@@ -298,6 +300,8 @@ void LlamaContextAttentionLayer<T>::fusedMultiHeadAttention(T** key_cache_ptr
qk_buf_float_, qk_buf_float_,
cu_seqlens, cu_seqlens,
nullptr, nullptr,
nullptr,
context_lengths,
group_size, group_size,
layout_q, layout_q,
layout_k, layout_k,
......
...@@ -72,6 +72,7 @@ public: ...@@ -72,6 +72,7 @@ public:
size_t cache_layer_offset, size_t cache_layer_offset,
T* attention_mask, T* attention_mask,
int* cu_seqlens, int* cu_seqlens,
int* context_lengths,
int batch_size, int batch_size,
int max_q_len, int max_q_len,
int max_k_len, int max_k_len,
......
...@@ -93,7 +93,8 @@ LlamaV2<T>::LlamaV2(size_t head_num, ...@@ -93,7 +93,8 @@ LlamaV2<T>::LlamaV2(size_t head_num,
TM_LOG_DEBUG(__PRETTY_FUNCTION__); TM_LOG_DEBUG(__PRETTY_FUNCTION__);
TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_); TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_; vocab_size_padded_ =
(vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
size_t elem_bits = 0; size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) { if (quant_policy & QuantPolicy::kCacheKVInt8) {
......
...@@ -95,8 +95,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path) ...@@ -95,8 +95,10 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type); loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
loadWeightFromBin( loadWeightFromBin((T*)post_decoder_embedding_kernel,
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type); {hidden_units_ * vocab_size_padded_},
dir_path + "output.weight",
model_file_type);
for (unsigned layer = 0; layer < num_layer_; ++layer) { for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type); decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
......
...@@ -15,10 +15,14 @@ struct BlockInfo { ...@@ -15,10 +15,14 @@ struct BlockInfo {
__device__ BlockInfo(const Params& params, const int bidb): __device__ BlockInfo(const Params& params, const int bidb):
sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]), sum_s_q(!Varlen || params.cu_seqlens_q == nullptr ? -1 : params.cu_seqlens_q[bidb]),
sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]), sum_s_k(!Varlen || params.cu_seqlens_k == nullptr ? -1 : params.cu_seqlens_k[bidb]),
actual_seqlen_q(!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q : actual_seqlen_q(params.actual_seqlen_q == nullptr ?
params.cu_seqlens_q[bidb + 1] - sum_s_q), (!Varlen || params.cu_seqlens_q == nullptr ? params.seqlen_q :
actual_seqlen_k(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k : params.cu_seqlens_q[bidb + 1] - sum_s_q) :
params.cu_seqlens_k[bidb + 1] - sum_s_k) params.actual_seqlen_q[bidb]),
actual_seqlen_k(params.actual_seqlen_k == nullptr ?
(!Varlen || params.cu_seqlens_k == nullptr ? params.seqlen_k :
params.cu_seqlens_k[bidb + 1] - sum_s_k) :
params.actual_seqlen_k[bidb])
{ {
} }
......
...@@ -16,7 +16,7 @@ constexpr int D_DIM = 2; ...@@ -16,7 +16,7 @@ constexpr int D_DIM = 2;
//////////////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////////////////
struct Qkv_params { struct Qkv_params {
using index_t = uint32_t; using index_t = size_t;
// The QKV matrices. // The QKV matrices.
void* __restrict__ q_ptr; void* __restrict__ q_ptr;
void* __restrict__ k_ptr; void* __restrict__ k_ptr;
...@@ -25,8 +25,8 @@ struct Qkv_params { ...@@ -25,8 +25,8 @@ struct Qkv_params {
// batched ptr inputs. // batched ptr inputs.
void** __restrict__ k_batched_ptr = nullptr; void** __restrict__ k_batched_ptr = nullptr;
void** __restrict__ v_batched_ptr = nullptr; void** __restrict__ v_batched_ptr = nullptr;
int k_batched_offset = 0; size_t k_batched_offset = 0;
int v_batched_offset = 0; size_t v_batched_offset = 0;
// The stride between rows of the Q, K and V matrices. // The stride between rows of the Q, K and V matrices.
index_t q_batch_stride; index_t q_batch_stride;
...@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params { ...@@ -72,6 +72,10 @@ struct Flash_fwd_params: public Qkv_params {
int* __restrict__ cu_seqlens_q; int* __restrict__ cu_seqlens_q;
int* __restrict__ cu_seqlens_k; int* __restrict__ cu_seqlens_k;
// array of length b with actual length of each sequence
int* __restrict__ actual_seqlen_q;
int* __restrict__ actual_seqlen_k;
void* __restrict__ blockmask; void* __restrict__ blockmask;
bool is_bf16; bool is_bf16;
......
...@@ -121,6 +121,9 @@ public: ...@@ -121,6 +121,9 @@ public:
fwd_params.cu_seqlens_q = params.cu_seqlens_q; fwd_params.cu_seqlens_q = params.cu_seqlens_q;
fwd_params.cu_seqlens_k = params.cu_seqlens_k; fwd_params.cu_seqlens_k = params.cu_seqlens_k;
fwd_params.actual_seqlen_q = params.actual_seqlen_q;
fwd_params.actual_seqlen_k = params.actual_seqlen_k;
fwd_params.blockmask = reinterpret_cast<void*>(params.mask); fwd_params.blockmask = reinterpret_cast<void*>(params.mask);
fwd_params.is_bf16 = false; fwd_params.is_bf16 = false;
......
...@@ -70,10 +70,10 @@ struct LlamaAttentionKernel: ...@@ -70,10 +70,10 @@ struct LlamaAttentionKernel:
scalar_t** v_batch_seqs_ptr = nullptr; scalar_t** v_batch_seqs_ptr = nullptr;
output_t** o_batch_seqs_ptr = nullptr; output_t** o_batch_seqs_ptr = nullptr;
int q_batch_seqs_offset = 0; size_t q_batch_seqs_offset = 0;
int k_batch_seqs_offset = 0; size_t k_batch_seqs_offset = 0;
int v_batch_seqs_offset = 0; size_t v_batch_seqs_offset = 0;
int o_batch_seqs_offset = 0; size_t o_batch_seqs_offset = 0;
int32_t group_size = 1; int32_t group_size = 1;
...@@ -81,7 +81,7 @@ struct LlamaAttentionKernel: ...@@ -81,7 +81,7 @@ struct LlamaAttentionKernel:
template<typename ptr_t> template<typename ptr_t>
CUTLASS_DEVICE void CUTLASS_DEVICE void
update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, int batch_seq_offset, int batch_id, int strideB) update_batched_ptr(ptr_t& data_ptr, ptr_t* batch_seq_ptr, size_t batch_seq_offset, int batch_id, int strideB)
{ {
if (batch_seq_ptr != nullptr) if (batch_seq_ptr != nullptr)
data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset; data_ptr = batch_seq_ptr[batch_id] + batch_seq_offset;
......
...@@ -84,7 +84,7 @@ struct BaseAttentionLayout { ...@@ -84,7 +84,7 @@ struct BaseAttentionLayout {
int stride_seq; int stride_seq;
int stride_head; int stride_head;
bool use_seqlens = false; bool use_seqlens = false;
int batch_seqs_offset = 0; size_t batch_seqs_offset = 0;
T** batch_seqs = nullptr; T** batch_seqs = nullptr;
}; };
...@@ -98,6 +98,8 @@ struct BaseAttentionParams { ...@@ -98,6 +98,8 @@ struct BaseAttentionParams {
float* out_accum = nullptr; float* out_accum = nullptr;
int* cu_seqlens_q = nullptr; int* cu_seqlens_q = nullptr;
int* cu_seqlens_k = nullptr; int* cu_seqlens_k = nullptr;
int* actual_seqlen_q = nullptr;
int* actual_seqlen_k = nullptr;
size_t group_size = 1; size_t group_size = 1;
BaseAttentionLayout<T> layout_q; BaseAttentionLayout<T> layout_q;
BaseAttentionLayout<T> layout_k; BaseAttentionLayout<T> layout_k;
......
...@@ -278,6 +278,8 @@ int main(int argc, const char* argv[]) ...@@ -278,6 +278,8 @@ int main(int argc, const char* argv[])
// auto* input_lengths = (int*)allocator.malloc(sizeof(int) * batch_size, false); // auto* input_lengths = (int*)allocator.malloc(sizeof(int) * batch_size, false);
thrust::device_vector<int> input_lengths(batch_size); thrust::device_vector<int> input_lengths(batch_size);
thrust::host_vector<int> input_lengths_host(batch_size); thrust::host_vector<int> input_lengths_host(batch_size);
thrust::device_vector<int> kv_lengths(batch_size);
thrust::host_vector<int> kv_lengths_host(batch_size);
cudaRandomUniform<scalar_t>(query_ptr, batch_size * num_heads * seq_len * size_per_head); cudaRandomUniform<scalar_t>(query_ptr, batch_size * num_heads * seq_len * size_per_head);
cudaRandomUniform<scalar_t>(key_ptr, batch_size * num_heads * key_len * size_per_head); cudaRandomUniform<scalar_t>(key_ptr, batch_size * num_heads * key_len * size_per_head);
...@@ -285,13 +287,12 @@ int main(int argc, const char* argv[]) ...@@ -285,13 +287,12 @@ int main(int argc, const char* argv[])
cudaRandomUniform<scalar_t>(mask_ptr, batch_size * seq_len * key_len); cudaRandomUniform<scalar_t>(mask_ptr, batch_size * seq_len * key_len);
// create random length for batch // create random length for batch
{
std::uniform_int_distribution<int> dist{seq_len / 2, seq_len}; std::uniform_int_distribution<int> dist{seq_len / 2, seq_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); }; auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(input_lengths_host), end(input_lengths_host), gen); std::generate(begin(input_lengths_host), end(input_lengths_host), gen);
// for(int batch_id=0;batch_id<batch_size;++batch_id){
// input_lengths_host[batch_id] = seq_len;
// }
thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin()); thrust::copy(input_lengths_host.begin(), input_lengths_host.end(), input_lengths.begin());
}
size_t h_token_num = 0; size_t h_token_num = 0;
size_t* h_pinned_token_num; size_t* h_pinned_token_num;
auto input_lengths_ptr = thrust::raw_pointer_cast(input_lengths.data()); auto input_lengths_ptr = thrust::raw_pointer_cast(input_lengths.data());
...@@ -306,10 +307,16 @@ int main(int argc, const char* argv[]) ...@@ -306,10 +307,16 @@ int main(int argc, const char* argv[])
stream); stream);
cudaFreeHost((void*)h_pinned_token_num); cudaFreeHost((void*)h_pinned_token_num);
int* k_lens = (int*)allocator.malloc(batch_size * sizeof(int)); {
deviceFill(k_lens, batch_size, key_len, stream); std::uniform_int_distribution<int> dist{seq_len, key_len};
auto gen = [&dist, &mersenne_engine]() { return dist(mersenne_engine); };
std::generate(begin(kv_lengths_host), end(kv_lengths_host), gen);
thrust::copy(kv_lengths_host.begin(), kv_lengths_host.end(), kv_lengths.begin());
}
auto kv_lengths_ptr = thrust::raw_pointer_cast(kv_lengths.data());
// deviceFill(kv_lengths_ptr, batch_size, key_len, stream);
invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, k_lens, seq_len, key_len, batch_size, stream); invokeCreateCausalMasks(mask_ptr, input_lengths_ptr, kv_lengths_ptr, seq_len, key_len, batch_size, stream);
// deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream); // deviceFill(mask_ptr, batch_size*key_len*seq_len, scalar_t(1), stream);
// compute gt // compute gt
...@@ -356,6 +363,8 @@ int main(int argc, const char* argv[]) ...@@ -356,6 +363,8 @@ int main(int argc, const char* argv[])
accum_buf_ptr, accum_buf_ptr,
cu_seqlens_ptr, cu_seqlens_ptr,
nullptr, nullptr,
nullptr,
kv_lengths_ptr,
1, 1,
layout_q, layout_q,
layout_k, layout_k,
...@@ -367,10 +376,10 @@ int main(int argc, const char* argv[]) ...@@ -367,10 +376,10 @@ int main(int argc, const char* argv[])
int num_rows = 8; int num_rows = 8;
// printf("query:\n"); // printf("query:\n");
// printMatrix(query_ptr, num_rows, 8, size_per_head, true); // printMatrix(query_ptr, num_rows, 8, size_per_head, true);
printf("expect:\n"); // printf("expect:\n");
printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true); // printMatrix(expect_out_ptr, num_rows, 8, size_per_head, true);
printf("actual:\n"); // printf("actual:\n");
printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true); // printMatrix(actual_out_ptr, num_rows, 8, size_per_head, true);
checkResult( checkResult(
"all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true); "all close:", actual_out_ptr, expect_out_ptr, batch_size * num_heads * seq_len * size_per_head, true, true);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment