Unverified Commit 4a60b45d authored by Li Zhang, committed by GitHub

[Feature] Support Qwen-7B, dynamic NTK scaling and logN scaling in turbomind (#230)

* qwen support

* dynamic ntk & logn attn

* fix ntk & add chat template

* fix ntk scaling & stop words

* fix lint

* add tiktoken to requirements.txt

* fix tokenizer, set model format automatically

* update model.py

* update readme

* fix lint
parent 62b60db7
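The two long-context techniques named in the title work as follows: dynamic NTK scaling enlarges the RoPE base once the sequence grows past the trained window, so rotary frequencies are stretched rather than extrapolated, while logN scaling rescales queries at long positions to keep attention logits stable (see the sketch after the new llama_params.h file below). The snippet here is a minimal reference sketch of the Qwen-style NTK base adjustment only; the helper name `dynamic_ntk_base` and the exact formula used inside the turbomind kernels are assumptions, not code from this commit.

```cpp
#include <algorithm>
#include <cmath>

// Reference-only sketch (not code from this commit): Qwen-style dynamic NTK
// adjustment of the RoPE base for sequences longer than the trained window.
inline float dynamic_ntk_base(float base, int rot_dim, int seq_len, int max_pos_emb)
{
    if (seq_len <= max_pos_emb) {
        return base;  // inside the trained context window: keep the original base
    }
    // Pick an alpha large enough that the stretched frequencies cover seq_len.
    const float context = std::log2(static_cast<float>(seq_len) / max_pos_emb) + 1.0f;
    const float alpha   = std::max(std::pow(2.0f, std::ceil(context)) - 1.0f, 1.0f);
    // Exponent dim / (dim - 2) leaves the highest rotary frequency unchanged.
    return base * std::pow(alpha, static_cast<float>(rot_dim) / (rot_dim - 2));
}
```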
@@ -22,6 +22,7 @@
 #include "src/turbomind/models/llama/LlamaDenseWeight.h"
 #include "src/turbomind/models/llama/LlamaLinear.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/nccl_utils.h"
@@ -33,17 +34,16 @@ public:
     void freeBuffer();
     void allocateBuffer(size_t batch_size, int key_len, int max_memory_len);

     LlamaDecoderSelfAttentionLayer(size_t head_num,
                                    size_t kv_head_num,
                                    size_t size_per_head,
-                                   size_t rotary_embedding_dim,
-                                   bool neox_rotary_style,
+                                   const LlamaAttentionParams& attn_params,
                                    NcclParam tensor_para,
                                    cudaStream_t stream,
                                    cublasMMWrapper* cublas_wrapper,
                                    IAllocator* allocator,
                                    bool is_free_buffer_after_forward,
                                    int quant_policy):
         head_num_(head_num),
         kv_head_num_(kv_head_num),
         size_per_head_(size_per_head),
@@ -51,8 +51,7 @@ public:
         local_head_num_(head_num / tensor_para.world_size_),
         local_kv_head_num_(kv_head_num_ / tensor_para.world_size_),
         local_hidden_units_(hidden_units_ / tensor_para.world_size_),
-        rotary_embedding_dim_(rotary_embedding_dim),
-        neox_rotary_style_(neox_rotary_style),
+        params_(attn_params),
         tensor_para_(tensor_para),
         stream_(stream),
         linear_(cublas_wrapper, stream),
@@ -77,11 +76,10 @@ private:
     const size_t local_head_num_;
     const size_t local_kv_head_num_;
     const size_t local_hidden_units_;
-    const size_t rotary_embedding_dim_;
     const bool is_free_buffer_after_forward_;
     const int quant_policy_;
-    const bool neox_rotary_style_;
+    const LlamaAttentionParams& params_;

     NcclParam tensor_para_;
@@ -91,13 +89,6 @@ private:
     T* qkv_buf_ = nullptr;
     T* context_buf_ = nullptr;

-    // T* weight_buf_ = nullptr;
-    // T* k_cache_buf_{};
-    // T* v_cache_buf_{};
-    // T* tmp_k_cache_buf_{};
-    // T* tmp_v_cache_buf_{};
-    // T* tmp_cache_buf_{};

     bool is_allocate_buffer_{};
 };
...
@@ -91,11 +91,15 @@ void LlamaFfnLayer<T>::forward(TensorMap* output_tensors,
             gating_buf_, ffn_input_data, num_token, weights->fused_gating_intermediate, LlamaLinear<T>::kFusedSiluFfn);
     }
     else {
+        // w1(x)
         linear_.forward(gating_buf_, ffn_input_data, num_token, weights->gating);
+        // w3(x)
         linear_.forward(inter_buf_, ffn_input_data, num_token, weights->intermediate);
+        // silu(w1(x)) * w3(x)
         activation(num_token);
     }

+    // w2(x)
     linear_.forward(ffn_output_data, gating_buf_, num_token, weights->output);

     POP_RANGE;
...
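The comments added above label the three projections of the LLaMA-style gated FFN: the block computes w2(silu(w1(x)) * w3(x)), with w1 the gating projection, w3 the intermediate projection, and w2 the output projection. Below is a minimal scalar sketch of the element-wise gating step that the `activation(num_token)` call is commented to perform; it is for reference only, since the real layer runs these as batched GEMMs through LlamaLinear on the GPU.

```cpp
#include <cmath>
#include <vector>

// Reference-only sketch: the element-wise gating between the w1 and w3
// projections, i.e. silu(w1(x)) * w3(x).
std::vector<float> silu_gate(const std::vector<float>& w1x, const std::vector<float>& w3x)
{
    std::vector<float> out(w1x.size());
    for (size_t i = 0; i < w1x.size(); ++i) {
        const float silu = w1x[i] / (1.0f + std::exp(-w1x[i]));  // silu(a) = a * sigmoid(a)
        out[i]           = silu * w3x[i];
    }
    return out;  // this product is what the final w2 (output) projection consumes
}
```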
@@ -28,6 +28,7 @@
 #include "src/turbomind/models/llama/LlamaNcclGuard.h"
 #include "src/turbomind/models/llama/LlamaWeight.h"
 #include "src/turbomind/models/llama/Request.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/models/llama/llama_utils.h"
 #include "src/turbomind/utils/Tensor.h"
 #include "src/turbomind/utils/cuda_utils.h"
@@ -45,7 +46,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
                     size_t inter_size,
                     size_t num_layer,
                     size_t vocab_size,
-                    size_t rotary_embedding_dim,
+                    const LlamaAttentionParams& attn_params,
                     float norm_eps,
                     int max_batch_size,
                     int max_context_token_num,
@@ -70,7 +71,6 @@ LlamaV2<T>::LlamaV2(size_t head_num,
     inter_size_(inter_size),
     num_layer_(num_layer),
     vocab_size_(vocab_size),
-    rotary_embedding_dim_(rotary_embedding_dim),
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
     end_id_(end_id),
@@ -116,7 +116,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
                                            cache_chunk_size,
                                            tensor_para.rank_,
                                            allocator);

-    initialize(kv_head_num, use_context_fmha, quant_policy);
+    initialize(attn_params, kv_head_num, use_context_fmha, quant_policy);
     start();
 }
@@ -131,7 +131,10 @@ LlamaV2<T>::~LlamaV2()
 }

 template<typename T>
-void LlamaV2<T>::initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy)
+void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
+                            size_t kv_head_num,
+                            bool use_context_fmha,
+                            int quant_policy)
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
@@ -140,7 +143,7 @@ void LlamaV2<T>::initialize(size_t kv_head_num, bool use_context_fmha, int quant
                                          size_per_head_,
                                          inter_size_,
                                          num_layer_,
-                                         rotary_embedding_dim_,
+                                         attn_params,
                                          rmsnorm_eps_,
                                          tensor_para_,
                                          stream_,
@@ -155,7 +158,7 @@ void LlamaV2<T>::initialize(size_t kv_head_num, bool use_context_fmha, int quant
                                          size_per_head_,
                                          inter_size_,
                                          num_layer_,
-                                         rotary_embedding_dim_,
+                                         attn_params,
                                          rmsnorm_eps_,
                                          tensor_para_,
                                          stream_,
...
@@ -54,7 +54,7 @@ public:
             size_t inter_size,
             size_t num_layer,
             size_t vocab_size,
-            size_t rotary_embedding_dim,
+            const LlamaAttentionParams& attn_params,
             float norm_eps,
             int max_batch_size,
             int max_context_token_num,
@@ -96,7 +96,8 @@ private:
     void internalThreadEntry(int device_id);

-    void initialize(size_t kv_head_num, bool use_context_fmha, int quant_policy);
+    void
+    initialize(const LlamaAttentionParams& attn_params, size_t kv_head_num, bool use_context_fmha, int quant_policy);

     void embeddingLookup(T* embeddings, const int* token_ids_buf, int batch_size, int step);
@@ -155,7 +156,6 @@ private:
     const size_t inter_size_;
     const size_t num_layer_;
     const size_t vocab_size_;
-    const size_t rotary_embedding_dim_;
     float rmsnorm_eps_ = 1e-6f;

     static constexpr bool neox_rotary_style_ = false;
...
// Copyright (c) OpenMMLab. All rights reserved.

#pragma once

namespace turbomind {

struct LlamaAttentionParams {
    int  rotray_embedding_dim;
    int  max_position_embeddings;
    bool use_dynamic_ntk;
    bool use_logn_attn;
};

}  // namespace turbomind
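The new llama_params.h above groups the rotary dimension, the trained context length, and the two Qwen long-context switches into one struct that is threaded through the layers in place of the old scalar `rotary_embedding_dim`. As a usage illustration of the `use_logn_attn` and `max_position_embeddings` fields, here is a minimal sketch of logN attention scaling in the form Qwen describes it; the helper name and its placement are assumptions, not part of this commit.

```cpp
#include <cmath>

#include "src/turbomind/models/llama/llama_params.h"

// Reference-only sketch: Qwen-style logN attention scaling. Queries at positions
// beyond the trained window are scaled by log_{max_pos}(pos) so attention
// logits keep a stable magnitude at long context.
inline float logn_scale(const turbomind::LlamaAttentionParams& params, int position)
{
    if (!params.use_logn_attn || position <= params.max_position_embeddings) {
        return 1.0f;  // no scaling inside the trained window, or when disabled
    }
    return std::log(static_cast<float>(position)) / std::log(static_cast<float>(params.max_position_embeddings));
}
```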
@@ -122,7 +122,6 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     inter_size_ = reader.GetInteger("llama", "inter_size");
     num_layer_ = reader.GetInteger("llama", "num_layer");
     vocab_size_ = reader.GetInteger("llama", "vocab_size");
-    rotary_embedding_dim_ = reader.GetInteger("llama", "rotary_embedding");
     norm_eps_ = reader.GetFloat("llama", "norm_eps");
     start_id_ = reader.GetInteger("llama", "start_id");
     end_id_ = reader.GetInteger("llama", "end_id");
@@ -137,6 +136,11 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     quant_policy_ = reader.GetInteger("llama", "quant_policy", 0);
     group_size_ = reader.GetInteger("llama", "group_size", 0);

+    attn_params_.rotray_embedding_dim = reader.GetInteger("llama", "rotary_embedding");
+    attn_params_.max_position_embeddings = reader.GetInteger("llama", "max_position_embeddings", 0);
+    attn_params_.use_dynamic_ntk = reader.GetInteger("llama", "use_dynamic_ntk", 0);
+    attn_params_.use_logn_attn = reader.GetInteger("llama", "use_logn_attn", 0);
+
     handleMissingParams();

     if (max_context_token_num_ <= max_batch_size_) {
@@ -222,7 +226,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
                               inter_size_,
                               num_layer_,
                               vocab_size_,
-                              rotary_embedding_dim_,
+                              attn_params_,
                               norm_eps_,
                               max_batch_size_,
                               max_context_token_num_,
...
@@ -21,6 +21,7 @@
 #pragma once

 #include "src/turbomind/models/llama/LlamaV2.h"
+#include "src/turbomind/models/llama/llama_params.h"
 #include "src/turbomind/triton_backend/llama/LlamaTritonModelInstance.h"
 #include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include "src/turbomind/utils/cuda_utils.h"
@@ -73,29 +74,29 @@ private:
         std::pair<std::vector<ft::NcclParam>, std::vector<ft::NcclParam>> nccl_params,
         std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr);

     size_t head_num_;
     size_t kv_head_num_;
     size_t size_per_head_;
     size_t inter_size_;
     size_t num_layer_;
     size_t vocab_size_;
-    size_t rotary_embedding_dim_;
+    turbomind::LlamaAttentionParams attn_params_;
     float norm_eps_;
     int max_batch_size_;
     int max_context_token_num_;
     int session_len_;
     int step_length_;
     int start_id_;
     int end_id_;
     int cache_max_entry_count_;
     int cache_chunk_size_;
     int use_context_fmha_;
     size_t tensor_para_size_;
     size_t pipeline_para_size_;
     ft::WeightType weight_type_;
     bool attn_bias_;
     int quant_policy_;
     int group_size_;

     // shared weights for each device
     std::vector<std::shared_ptr<ft::LlamaWeight<T>>> shared_weights_;
...