Unverified commit f07b697b, authored by Li Zhang, committed by GitHub

[Feature] Support Llama-2 with GQA (#147)

* add GQA for llama2

* fix model conversion

* fix lint & remove dev log

* update news

* minor

* fix allocation size

* fix split_dim for w_qkv.bias
parent 2a475478
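
Note on the change set below: with grouped-query attention (GQA), Llama-2 models such as the 70B variant use fewer key/value heads than query heads, so the code gains a `kv_head_num` parameter and a derived replication factor `head_n_rep = head_num / kv_head_num` that is threaded through the weight classes, the KV-cache kernels, and the Triton model config. A minimal sketch of that grouping arithmetic, reusing the names from the diff (the standalone helper is illustrative and not part of the commit):

#include <cassert>
#include <cstdio>

// Illustrative helper: which KV head a given query head reads under GQA.
// head_n_rep query heads share one KV head; MHA is the special case head_n_rep == 1.
static int kv_head_of(int head_id, int head_num, int kv_head_num)
{
    assert(kv_head_num > 0 && head_num % kv_head_num == 0);
    const int head_n_rep = head_num / kv_head_num;
    return head_id / head_n_rep;  // same expression the updated kernels use
}

int main()
{
    printf("%d %d %d\n", kv_head_of(0, 32, 8), kv_head_of(5, 32, 8), kv_head_of(31, 32, 8));  // 0 1 7
    printf("%d\n", kv_head_of(13, 32, 32));  // 13: kv_head_num == head_num degenerates to MHA
    return 0;
}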
@@ -23,7 +23,9 @@
 namespace turbomind {

 template<typename T>
-LlamaWeight<T>::LlamaWeight(size_t hidden_units,
+LlamaWeight<T>::LlamaWeight(size_t head_num,
+                            size_t kv_head_num,
+                            size_t size_per_head,
                             size_t inter_size,
                             size_t vocab_size,
                             size_t num_layer,
@@ -32,7 +34,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
                             size_t tensor_para_size,
                             size_t tensor_para_rank,
                             int    prefix_cache_len):
-    hidden_units_(hidden_units),
+    hidden_units_(head_num * size_per_head),
     inter_size_(inter_size),
     vocab_size_(vocab_size),
     num_layer_(num_layer),
@@ -43,8 +45,14 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
 {
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
-        decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(
-            hidden_units_, inter_size_, weight_type_, attn_bias, tensor_para_size_, tensor_para_rank_));
+        decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(head_num,
+                                                                       kv_head_num,
+                                                                       size_per_head,
+                                                                       inter_size_,
+                                                                       weight_type_,
+                                                                       attn_bias,
+                                                                       tensor_para_size_,
+                                                                       tensor_para_rank_));
     }
     mallocWeights();
...
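
The constructor now takes `head_num`, `kv_head_num`, and `size_per_head` instead of a single `hidden_units`, because under GQA the attention projections are no longer square: the Q projection still produces `head_num * size_per_head` features, while K and V each produce only `kv_head_num * size_per_head`. A rough sizing sketch under that reading (the fused `w_qkv` layout is inferred from the "fix split_dim for w_qkv.bias" note above; the numbers are illustrative, and tensor parallelism and bias terms are ignored):

#include <cstddef>
#include <cstdio>

int main()
{
    // Illustrative GQA shape arithmetic (a 70B-like setting).
    const std::size_t head_num      = 64;
    const std::size_t kv_head_num   = 8;
    const std::size_t size_per_head = 128;

    const std::size_t hidden_units = head_num * size_per_head;  // matches hidden_units_(head_num * size_per_head)
    const std::size_t qkv_out      = (head_num + 2 * kv_head_num) * size_per_head;  // Q + K + V output features

    printf("w_qkv: %zu x %zu elements\n", hidden_units, qkv_out);       // 8192 x 10240
    printf("w_o:   %zu x %zu elements\n", hidden_units, hidden_units);  // 8192 x 8192
    return 0;
}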
@@ -28,7 +28,9 @@ namespace turbomind {
 template<typename T>
 struct LlamaWeight {
     LlamaWeight() = default;
-    LlamaWeight(size_t hidden_units,
+    LlamaWeight(size_t head_num,
+                size_t kv_head_num,
+                size_t size_per_head,
                 size_t inter_size,
                 size_t vocab_size,
                 size_t num_layer,
...
@@ -488,6 +488,7 @@ __global__ void transpose_value_cache(T* v_dst, //
                                       const T**    v_src,
                                       const size_t src_offset,
                                       const int    head_num,
+                                      const int    head_n_rep,
                                       const int    size_per_head,
                                       const int*   seq_length,
                                       const int    max_kv_len,
@@ -511,9 +512,9 @@ __global__ void transpose_value_cache(T* v_dst, //
     if (v_seq_len_id < seq_len) {
         // [B, H, s, D/x] <- [B, H, S[:s], D/x]
-        const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len +  // H
+        const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len +  // H
                                 v_seq_len_id * size_per_head_div_x +  // s
                                 v_head_size_id;                       // D/x
         const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len +  // B
                                 head_id * size_per_head_div_x * max_kv_len +              // H
@@ -529,6 +530,7 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
                                            const int8_t** v_src,
                                            const size_t   src_offset,
                                            const int      head_num,
+                                           const int      head_n_rep,
                                            const int      size_per_head,
                                            const int*     seq_length,
                                            const int      max_kv_len,
@@ -553,9 +555,9 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
     if (v_seq_len_id < seq_len) {
         // [B, H, s, D/x] <- [B, H, S[:s], D/x]
-        const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len +  // H
+        const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len +  // H
                                 v_seq_len_id * size_per_head_div_x +  // s
                                 v_head_size_id;                       // D/x
         const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len +  // B
                                 head_id * size_per_head_div_x * max_kv_len +              // H
@@ -583,6 +585,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
                             int          max_seq_len,
                             int          size_per_head,
                             int          head_num,
+                            int          head_n_rep,
                             cudaStream_t stream,
                             int          quant,
                             const float* kv_scale)
@@ -597,6 +600,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
                             reinterpret_cast<const int8_t**>(key_cache),
                             src_offset,
                             head_num,
+                            head_n_rep,
                             size_per_head,
                             key_length,
                             max_kv_len,
@@ -607,6 +611,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
                             reinterpret_cast<const int8_t**>(val_cache),
                             src_offset,
                             head_num,
+                            head_n_rep,
                             size_per_head,
                             key_length,
                             max_kv_len,
@@ -614,11 +619,25 @@ void invokeTransposeKVCache(T* key_cache_trans,
                             kv_scale[1]);
     }
     else {
-        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
-            key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
-
-        transpose_value_cache<<<grid, block_sz, 0, stream>>>(
-            val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
+        transpose_value_cache<<<grid, block_sz, 0, stream>>>(key_cache_trans,
+                                                             key_cache,
+                                                             src_offset,
+                                                             head_num,
+                                                             head_n_rep,
+                                                             size_per_head,
+                                                             key_length,
+                                                             max_kv_len,
+                                                             max_seq_len);
+
+        transpose_value_cache<<<grid, block_sz, 0, stream>>>(val_cache_trans,
+                                                             val_cache,
+                                                             src_offset,
+                                                             head_num,
+                                                             head_n_rep,
+                                                             size_per_head,
+                                                             key_length,
+                                                             max_kv_len,
+                                                             max_seq_len);
     }
 }
@@ -633,6 +652,7 @@ template void invokeTransposeKVCache(float*,
                                      int,
                                      int,
                                      int,
+                                     int,
                                      cudaStream_t stream,
                                      int,
                                      const float*);
@@ -647,6 +667,7 @@ template void invokeTransposeKVCache(half*,
                                      int,
                                      int,
                                      int,
+                                     int,
                                      cudaStream_t stream,
                                      int,
                                      const float*);
...
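
Reading the kernel changes above: as the new indexing suggests, the transpose kernels gather from a cache that stores only `kv_head_num` heads into the `[B, H, s, D/x]` layout that downstream attention addresses per query head, which is why the destination index still runs over `head_num` while the source head index is collapsed to `head_id / head_n_rep`. A host-side rendering of the two index formulas with made-up sizes (illustrative only; names follow the kernel parameters in the diff):

#include <cstdint>
#include <cstdio>

int main()
{
    const int head_num = 32, head_n_rep = 4;   // source cache then holds head_num / head_n_rep = 8 KV heads
    const int size_per_head_div_x = 16;        // head dim divided by the vectorization width x
    const int max_seq_len = 2048, max_kv_len = 512;
    const int batch_id = 0, head_id = 13, v_seq_len_id = 7, v_head_size_id = 3;

    // Source: compact per-KV-head cache, addressed by head_id / head_n_rep.
    const int64_t src_idx = int64_t(head_id / head_n_rep) * size_per_head_div_x * max_seq_len  // H
                          + int64_t(v_seq_len_id) * size_per_head_div_x                        // s
                          + v_head_size_id;                                                    // D/x
    // Destination: expanded [B, H, s, D/x] buffer over all head_num query heads.
    const int64_t dst_idx = int64_t(batch_id) * head_num * size_per_head_div_x * max_kv_len    // B
                          + int64_t(head_id) * size_per_head_div_x * max_kv_len                // H
                          + int64_t(v_seq_len_id) * size_per_head_div_x                        // s
                          + v_head_size_id;                                                    // D/x

    printf("query head %d -> KV head %d, src_idx=%lld, dst_idx=%lld\n",
           head_id, head_id / head_n_rep, (long long)src_idx, (long long)dst_idx);
    return 0;
}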
@@ -62,6 +62,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
                             int          max_seq_len,
                             int          size_per_head,
                             int          head_num,
+                            int          head_n_rep,
                             cudaStream_t stream,
                             int          quant_policy,
                             const float* kv_scale);
...
@@ -59,6 +59,11 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
 template<typename T>
 void LlamaTritonModel<T>::handleMissingParams()
 {
+    if (kv_head_num_ == 0) {
+        kv_head_num_ = head_num_;
+        TM_LOG_WARNING("[LlamaTritonModel] `kv_head_num` is not set, default to `head_num` (%d).", (int)kv_head_num_);
+    }
+
     if (!max_batch_size_) {
         max_batch_size_ = 32;
         TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_);
@@ -112,6 +117,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
     model_name_    = reader.Get("llama", "model_name");
     head_num_      = reader.GetInteger("llama", "head_num");
+    kv_head_num_   = reader.GetInteger("llama", "kv_head_num", 0);
     size_per_head_ = reader.GetInteger("llama", "size_per_head");
     inter_size_    = reader.GetInteger("llama", "inter_size");
     num_layer_     = reader.GetInteger("llama", "num_layer");
@@ -211,6 +217,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
     ft::FT_CHECK(pipeline_para.world_size_ = pipeline_para_size_);
     auto llama = std::make_unique<ft::LlamaV2<T>>(head_num_,
+                                                  kv_head_num_,
                                                   size_per_head_,
                                                   inter_size_,
                                                   num_layer_,
@@ -283,7 +290,9 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
     const int tensor_para_rank   = rank % tensor_para_size_;
     const int pipeline_para_rank = rank / tensor_para_size_;
     ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0);
-    shared_weights_[device_id] = std::make_shared<ft::LlamaWeight<T>>(head_num_ * size_per_head_,
+    shared_weights_[device_id] = std::make_shared<ft::LlamaWeight<T>>(head_num_,
+                                                                      kv_head_num_,
+                                                                      size_per_head_,
                                                                       inter_size_,
                                                                       vocab_size_,
                                                                       num_layer_,
@@ -301,16 +310,16 @@ std::string LlamaTritonModel<T>::toString()
 {
     std::stringstream ss;
     ss << "Model: "
-       << "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_
-       << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_
-       << "\nmax_batch_size: " << max_batch_size_ << "\nmax_context_token_num: " << max_context_token_num_
-       << "\nsession_len: " << session_len_ << "\nstep_length: " << step_length_
-       << "\ncache_max_entry_count: " << cache_max_entry_count_ << "\ncache_chunk_size: " << cache_chunk_size_
-       << "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
-       << "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
-       << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
-       << "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_
-       << "\nquant_policy: " << quant_policy_ << std::endl;
+       << "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_
+       << "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_
+       << "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_
+       << "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_
+       << "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_
+       << "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
+       << "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
+       << "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
+       << "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
+       << "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
     return ss.str();
 }
...
@@ -74,6 +74,7 @@ private:
                      std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr);

     size_t head_num_;
+    size_t kv_head_num_;
     size_t size_per_head_;
     size_t inter_size_;
     size_t num_layer_;
...
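
On the configuration side, `kv_head_num` is read from the `[llama]` section of the model's ini config with a default of 0, and `handleMissingParams()` falls back to `head_num`, so existing MHA configs keep working unchanged (`head_n_rep == 1`). A small sketch of that fallback with Llama-2-70B-style numbers (64 query heads, 8 KV heads; the values are illustrative, only the fallback rule comes from the diff):

#include <cstddef>
#include <cstdio>

int main()
{
    std::size_t head_num    = 64;  // as read from the [llama] section of the converted model's config
    std::size_t kv_head_num = 8;   // set for GQA checkpoints; 0 / absent for plain MHA checkpoints

    if (kv_head_num == 0) {
        kv_head_num = head_num;  // same fallback handleMissingParams() applies
    }
    printf("head_n_rep = %zu\n", head_num / kv_head_num);  // 8 here; 1 whenever the fallback fires
    return 0;
}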