Unverified commit b8354dae, authored by Li Zhang and committed by GitHub

Disable attention mask when it is not needed (#813)

* disable attention mask when not needed

* fix for sm<80 and float data type
parent d5a89465
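
In short, the decode path now materializes the causal attention mask only when the fused FMHA kernel cannot be used. Below is a minimal sketch (not code from this commit; the helper name and parameters are illustrative) of the predicate that the constructor change encodes: the mask is needed unless FMHA is enabled, the GPU is sm80 or newer, and the element type is a 2-byte half/bf16 type.

#include <cstddef>

// Hypothetical helper restating the gating logic introduced by this commit.
// A causal mask must be built unless the fused FMHA path is usable.
static bool NeedCausalMask(bool use_fmha, int sm_version, std::size_t elem_size)
{
    bool need = false;
#ifdef _MSC_VER
    // On MSVC builds both unfused MHA and flash attention 1 need the mask.
    need = true;
#endif
    if (!use_fmha || sm_version < 80 || elem_size != 2) {
        // No FMHA, pre-Ampere GPU, or non-16-bit dtype: fall back to the masked path.
        need = true;
    }
    return need;
}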
...
@@ -894,6 +894,8 @@ LlamaBatch<T>::LlamaBatch(const EngineParams& params, int cache_block_seq_len, i
         session_len_ = max_session_len;
     }
+    FT_CHECK(max_context_token_num_ >= session_len_);
+
     for (auto& s : states_) {
         s.requests.resize(max_batch_size_);
         s.sequences.resize(max_batch_size_);
...
...
@@ -15,8 +15,14 @@ void UnifiedDecoder<T>::allocateBuffer(size_t num_token, size_t pf_batch_size, s
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);

     if (pf_batch_size) {
-        attention_mask_ =
-            (T*)allocator_->reMalloc(attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        if (need_causal_mask_) {
+            attention_mask_ = (T*)allocator_->reMalloc(
+                attention_mask_, sizeof(T) * pf_batch_size * pf_max_q_len * pf_max_k_len, false);
+        }
+        else {
+            // just to avoid nullptr
+            attention_mask_ = (T*)allocator_->reMalloc(attention_mask_, sizeof(T), false);
+        }
         padding_offset_ =
             (int*)allocator_->reMalloc(padding_offset_, sizeof(int) * pf_batch_size * pf_max_q_len, false);
         cu_seqlens_ = (int*)allocator_->reMalloc(cu_seqlens_, sizeof(int) * (pf_batch_size + 1), false);
...
@@ -162,14 +168,16 @@ void UnifiedDecoder<T>::forward(TensorMap* outputs, const TensorMap* inputs, con
         FT_CHECK(tmp_token_num == token_num - dc_batch_size);

-        invokeCreateCausalMasks(attention_mask_,
-                                input_length + pf_offset,
-                                context_length + pf_offset,
-                                pf_max_q_len,
-                                pf_max_k_len,
-                                pf_batch_size,
-                                stream_);
-        sync_check_cuda_error();
+        if (need_causal_mask_) {
+            invokeCreateCausalMasks(attention_mask_,
+                                    input_length + pf_offset,
+                                    context_length + pf_offset,
+                                    pf_max_q_len,
+                                    pf_max_k_len,
+                                    pf_batch_size,
+                                    stream_);
+            sync_check_cuda_error();
+        }
     }

     /////////////////////////////////////////////
...
...
@@ -5,6 +5,7 @@
 #include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/unified_attention_layer.h"
 #include "src/turbomind/utils/cublasMMWrapper.h"
+#include "src/turbomind/utils/cuda_utils.h"
 #include "src/turbomind/utils/nccl_utils.h"

 namespace turbomind {
...
@@ -46,6 +47,8 @@ protected:
     const DataType dtype_;

+    bool need_causal_mask_{false};
+
     using WeightType = LlamaDecoderLayerWeight<T>;

     void forwardSelfAttn(T* attn_io,
...
@@ -88,6 +91,14 @@ public:
         tensor_para_(tensor_para),
         dtype_(getTensorType<T>())
     {
+#ifdef _MSC_VER
+        // Both unfused MHA and flash attention 1 need causal mask
+        need_causal_mask_ = true;
+#endif
+        // attention mask is not used for FA-1 (which requires sm80+ and half/bf16 data type)
+        if (!use_fmha || (getSMVersion() < 80 || sizeof(T) != 2)) {
+            need_causal_mask_ = true;
+        }
         initialize(attn_params, kv_head_num, use_fmha, cache_block_seq_len, quant_policy);
     }
...