Unverified commit f07b697b authored by Li Zhang, committed by GitHub

[Feature] Support Llama-2 with GQA (#147)

* add GQA for llama2

* fix model conversion

* fix lint & remove dev log

* update news

* minor

* fix allocation size

* fix split_dim for w_qkv.bias
parent 2a475478
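
For readers unfamiliar with the feature: grouped-query attention (GQA), used by Llama-2-70B, lets several query heads share one key/value head, so the KV cache only holds `kv_head_num` heads instead of `head_num`. The sketch below is illustrative only and not part of this commit; it shows the head mapping that the new `head_n_rep` parameter encodes, using the head counts of Llama-2-70B.

```cpp
// Illustrative sketch, not code from this commit: how query heads map onto
// KV heads under GQA. head_n_rep = head_num / kv_head_num is the number of
// query heads sharing one KV head; head_n_rep == 1 recovers plain MHA.
#include <cassert>
#include <cstdio>

int kv_head_of(int head_id, int head_n_rep)
{
    // consecutive query heads [g * head_n_rep, (g + 1) * head_n_rep) share KV head g
    return head_id / head_n_rep;
}

int main()
{
    const int head_num    = 64;  // Llama-2-70B query heads
    const int kv_head_num = 8;   // Llama-2-70B KV heads
    assert(head_num % kv_head_num == 0);
    const int head_n_rep = head_num / kv_head_num;  // 8 query heads per KV head

    for (int head_id = 0; head_id < head_num; head_id += 16) {
        std::printf("query head %d -> kv head %d\n", head_id, kv_head_of(head_id, head_n_rep));
    }
    return 0;
}
```
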
@@ -23,7 +23,9 @@
namespace turbomind {
template<typename T>
LlamaWeight<T>::LlamaWeight(size_t hidden_units,
LlamaWeight<T>::LlamaWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t vocab_size,
size_t num_layer,
@@ -32,7 +34,7 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
size_t tensor_para_size,
size_t tensor_para_rank,
int prefix_cache_len):
hidden_units_(hidden_units),
hidden_units_(head_num * size_per_head),
inter_size_(inter_size),
vocab_size_(vocab_size),
num_layer_(num_layer),
@@ -43,8 +45,14 @@ LlamaWeight<T>::LlamaWeight(size_t hidden_units,
{
decoder_layer_weights.reserve(num_layer_);
for (unsigned l = 0; l < num_layer_; ++l) {
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(
hidden_units_, inter_size_, weight_type_, attn_bias, tensor_para_size_, tensor_para_rank_));
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(head_num,
kv_head_num,
size_per_head,
inter_size_,
weight_type_,
attn_bias,
tensor_para_size_,
tensor_para_rank_));
}
mallocWeights();
......
@@ -28,7 +28,9 @@ namespace turbomind {
template<typename T>
struct LlamaWeight {
LlamaWeight() = default;
LlamaWeight(size_t hidden_units,
LlamaWeight(size_t head_num,
size_t kv_head_num,
size_t size_per_head,
size_t inter_size,
size_t vocab_size,
size_t num_layer,
......
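
The constructor now receives `head_num`, `kv_head_num` and `size_per_head` separately because, under GQA, the fused QKV projection output is no longer `3 * hidden_units`: Q keeps `head_num * size_per_head` output features while K and V shrink to `kv_head_num * size_per_head` each, which is what the "fix allocation size" and "fix split_dim for w_qkv.bias" commits adjust. A rough sketch of the resulting sizes, assuming (FasterTransformer-style) that `w_qkv` and its bias are split along the output dimension across tensor-parallel ranks:

```cpp
// Rough sketch under stated assumptions; not the actual LlamaDecoderLayerWeight code.
#include <cstddef>
#include <cstdio>

int main()
{
    const std::size_t head_num         = 64;   // query heads (Llama-2-70B)
    const std::size_t kv_head_num      = 8;    // KV heads
    const std::size_t size_per_head    = 128;
    const std::size_t tensor_para_size = 8;    // hypothetical TP degree

    const std::size_t hidden_units = head_num * size_per_head;                      // 8192
    const std::size_t qkv_out      = (head_num + 2 * kv_head_num) * size_per_head;  // 10240, not 3 * 8192
    const std::size_t qkv_out_rank = qkv_out / tensor_para_size;                    // per-rank slice of w_qkv and its bias

    std::printf("w_qkv: %zu x %zu in total, %zu x %zu per rank\n",
                hidden_units, qkv_out, hidden_units, qkv_out_rank);
    return 0;
}
```
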
@@ -488,6 +488,7 @@ __global__ void transpose_value_cache(T* v_dst, //
const T** v_src,
const size_t src_offset,
const int head_num,
const int head_n_rep,
const int size_per_head,
const int* seq_length,
const int max_kv_len,
@@ -511,9 +512,9 @@
if (v_seq_len_id < seq_len) {
// [B, H, s, D/x] <- [B, H, S[:s], D/x]
const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H
v_seq_len_id * size_per_head_div_x + // s
v_head_size_id; // D/x
const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H
v_seq_len_id * size_per_head_div_x + // s
v_head_size_id; // D/x
const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B
head_id * size_per_head_div_x * max_kv_len + // H
@@ -529,6 +530,7 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
const int8_t** v_src,
const size_t src_offset,
const int head_num,
const int head_n_rep,
const int size_per_head,
const int* seq_length,
const int max_kv_len,
@@ -553,9 +555,9 @@ __global__ void transpose_value_cache_int8(T* v_dst, //
if (v_seq_len_id < seq_len) {
// [B, H, s, D/x] <- [B, H, S[:s], D/x]
const int64_t src_idx = head_id * size_per_head_div_x * max_seq_len + // H
v_seq_len_id * size_per_head_div_x + // s
v_head_size_id; // D/x
const int64_t src_idx = head_id / head_n_rep * size_per_head_div_x * max_seq_len + // H
v_seq_len_id * size_per_head_div_x + // s
v_head_size_id; // D/x
const int64_t dst_idx = batch_id * head_num * size_per_head_div_x * max_kv_len + // B
head_id * size_per_head_div_x * max_kv_len + // H
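
The functional change in both cache-transpose kernels is the source index: the cache stores only `kv_head_num` heads, so query head `head_id` now reads from KV head `head_id / head_n_rep`, while the destination keeps the full `head_num` layout that the attention kernels expect. Below is a host-side sketch of the same index arithmetic; it is illustrative only, the variable names follow the kernel, and the tail of `dst_idx` is assumed to mirror `src_idx`.

```cpp
// Illustrative index arithmetic mirroring transpose_value_cache; not the kernel itself.
// src layout (per sequence): [kv_head_num, max_seq_len, size_per_head / x]
// dst layout:                [batch, head_num, max_kv_len, size_per_head / x]
#include <cstdint>

std::int64_t src_index(int head_id, int head_n_rep, int seq_id, int d,
                       int size_per_head_div_x, int max_seq_len)
{
    // query heads sharing a KV head collapse onto the same source slice
    return static_cast<std::int64_t>(head_id / head_n_rep) * size_per_head_div_x * max_seq_len
           + static_cast<std::int64_t>(seq_id) * size_per_head_div_x + d;
}

std::int64_t dst_index(int batch_id, int head_id, int head_num, int seq_id, int d,
                       int size_per_head_div_x, int max_kv_len)
{
    // destination keeps the full query-head count
    return static_cast<std::int64_t>(batch_id) * head_num * size_per_head_div_x * max_kv_len
           + static_cast<std::int64_t>(head_id) * size_per_head_div_x * max_kv_len
           + static_cast<std::int64_t>(seq_id) * size_per_head_div_x + d;
}
```
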
@@ -583,6 +585,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
int max_seq_len,
int size_per_head,
int head_num,
int head_n_rep,
cudaStream_t stream,
int quant,
const float* kv_scale)
@@ -597,6 +600,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
reinterpret_cast<const int8_t**>(key_cache),
src_offset,
head_num,
head_n_rep,
size_per_head,
key_length,
max_kv_len,
@@ -607,6 +611,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
reinterpret_cast<const int8_t**>(val_cache),
src_offset,
head_num,
head_n_rep,
size_per_head,
key_length,
max_kv_len,
@@ -614,11 +619,25 @@
kv_scale[1]);
}
else {
transpose_value_cache<<<grid, block_sz, 0, stream>>>(
key_cache_trans, key_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
transpose_value_cache<<<grid, block_sz, 0, stream>>>(
val_cache_trans, val_cache, src_offset, head_num, size_per_head, key_length, max_kv_len, max_seq_len);
transpose_value_cache<<<grid, block_sz, 0, stream>>>(key_cache_trans,
key_cache,
src_offset,
head_num,
head_n_rep,
size_per_head,
key_length,
max_kv_len,
max_seq_len);
transpose_value_cache<<<grid, block_sz, 0, stream>>>(val_cache_trans,
val_cache,
src_offset,
head_num,
head_n_rep,
size_per_head,
key_length,
max_kv_len,
max_seq_len);
}
}
@@ -633,6 +652,7 @@ template void invokeTransposeKVCache(float*,
int,
int,
int,
int,
cudaStream_t stream,
int,
const float*);
@@ -647,6 +667,7 @@ template void invokeTransposeKVCache(half*,
int,
int,
int,
int,
cudaStream_t stream,
int,
const float*);
......
@@ -62,6 +62,7 @@ void invokeTransposeKVCache(T* key_cache_trans,
int max_seq_len,
int size_per_head,
int head_num,
int head_n_rep,
cudaStream_t stream,
int quant_policy,
const float* kv_scale);
......
@@ -59,6 +59,11 @@ std::shared_ptr<AbstractTransformerModel> AbstractTransformerModel::createLlamaM
template<typename T>
void LlamaTritonModel<T>::handleMissingParams()
{
if (kv_head_num_ == 0) {
kv_head_num_ = head_num_;
TM_LOG_WARNING("[LlamaTritonModel] `kv_head_num` is not set, default to `head_num` (%d).", (int)kv_head_num_);
}
if (!max_batch_size_) {
max_batch_size_ = 32;
TM_LOG_WARNING("[LlamaTritonModel] `max_batch_size` is not set, default to %d.", (int)max_batch_size_);
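
Defaulting `kv_head_num` to `head_num` keeps older config files without the new key working: the replication factor becomes 1 and the GQA path degenerates to ordinary MHA. A hedged sketch of the derived value and the divisibility requirement follows; the explicit check is an assumption for illustration, not something this diff adds.

```cpp
// Hedged sketch of the fallback plus the divisibility requirement for GQA;
// the assert is an assumption for illustration, not taken from this commit.
#include <cassert>
#include <cstddef>

std::size_t resolve_head_n_rep(std::size_t head_num, std::size_t kv_head_num)
{
    if (kv_head_num == 0) {               // key missing in the config -> fall back to MHA
        kv_head_num = head_num;
    }
    assert(head_num % kv_head_num == 0);  // query heads must group evenly onto KV heads
    return head_num / kv_head_num;        // 1 for Llama-2-7B/13B (MHA), 8 for Llama-2-70B (64 / 8)
}
```
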
@@ -112,6 +117,7 @@ LlamaTritonModel<T>::LlamaTritonModel(size_t tensor_para_size,
model_name_ = reader.Get("llama", "model_name");
head_num_ = reader.GetInteger("llama", "head_num");
kv_head_num_ = reader.GetInteger("llama", "kv_head_num", 0);
size_per_head_ = reader.GetInteger("llama", "size_per_head");
inter_size_ = reader.GetInteger("llama", "inter_size");
num_layer_ = reader.GetInteger("llama", "num_layer");
@@ -211,6 +217,7 @@ std::unique_ptr<LlamaTritonSharedModelInstance<T>> LlamaTritonModel<T>::createSh
ft::FT_CHECK(pipeline_para.world_size_ == pipeline_para_size_);
auto llama = std::make_unique<ft::LlamaV2<T>>(head_num_,
kv_head_num_,
size_per_head_,
inter_size_,
num_layer_,
@@ -283,7 +290,9 @@ void LlamaTritonModel<T>::createSharedWeights(int device_id, int rank)
const int tensor_para_rank = rank % tensor_para_size_;
const int pipeline_para_rank = rank / tensor_para_size_;
ft::FT_CHECK(pipeline_para_size_ == 1 && pipeline_para_rank == 0);
shared_weights_[device_id] = std::make_shared<ft::LlamaWeight<T>>(head_num_ * size_per_head_,
shared_weights_[device_id] = std::make_shared<ft::LlamaWeight<T>>(head_num_,
kv_head_num_,
size_per_head_,
inter_size_,
vocab_size_,
num_layer_,
@@ -301,16 +310,16 @@ std::string LlamaTritonModel<T>::toString()
{
std::stringstream ss;
ss << "Model: "
<< "\nhead_num: " << head_num_ << "\nsize_per_head: " << size_per_head_ << "\ninter_size: " << inter_size_
<< "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_ << "\nattn_bias: " << attn_bias_
<< "\nmax_batch_size: " << max_batch_size_ << "\nmax_context_token_num: " << max_context_token_num_
<< "\nsession_len: " << session_len_ << "\nstep_length: " << step_length_
<< "\ncache_max_entry_count: " << cache_max_entry_count_ << "\ncache_chunk_size: " << cache_chunk_size_
<< "\nuse_context_fmha: " << use_context_fmha_ << "\nstart_id: " << start_id_
<< "\ntensor_para_size: " << tensor_para_size_ << "\npipeline_para_size: " << pipeline_para_size_
<< "\nenable_custom_all_reduce: " << enable_custom_all_reduce_ << "\nmodel_name: " << model_name_
<< "\nprefix_cache_len: " << prefix_cache_len_ << "\nmodel_dir: " << model_dir_
<< "\nquant_policy: " << quant_policy_ << std::endl;
<< "\nhead_num: " << head_num_ << "\nkv_head_num: " << kv_head_num_ << "\nsize_per_head: " << size_per_head_
<< "\ninter_size: " << inter_size_ << "\nnum_layer: " << num_layer_ << "\nvocab_size: " << vocab_size_
<< "\nattn_bias: " << attn_bias_ << "\nmax_batch_size: " << max_batch_size_
<< "\nmax_context_token_num: " << max_context_token_num_ << "\nsession_len: " << session_len_
<< "\nstep_length: " << step_length_ << "\ncache_max_entry_count: " << cache_max_entry_count_
<< "\ncache_chunk_size: " << cache_chunk_size_ << "\nuse_context_fmha: " << use_context_fmha_
<< "\nstart_id: " << start_id_ << "\ntensor_para_size: " << tensor_para_size_
<< "\npipeline_para_size: " << pipeline_para_size_ << "\nenable_custom_all_reduce: " << enable_custom_all_reduce_
<< "\nmodel_name: " << model_name_ << "\nprefix_cache_len: " << prefix_cache_len_
<< "\nmodel_dir: " << model_dir_ << "\nquant_policy: " << quant_policy_ << std::endl;
return ss.str();
}
......
@@ -74,6 +74,7 @@ private:
std::shared_ptr<ft::AbstractCustomComm> custom_all_reduce_comm = nullptr);
size_t head_num_;
size_t kv_head_num_;
size_t size_per_head_;
size_t inter_size_;
size_t num_layer_;
......