Unverified Commit 4903d3cc authored by Lyu Han, committed by GitHub

Pad tok_embedding and output weights to make their shape divisible by TP (#285)

* Pad tok_embedding and output weights to make their shape divisible by TP

* update

* update

* update

* update

* update llamaBatch
parent d5cb0be2
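
For readers skimming the diff: the padding step added to the export script amounts to rounding the vocabulary size up to the next multiple of the tensor-parallel degree and appending zero rows to both the token embedding and the output projection. A minimal sketch of that idea (the standalone helper below and its name are illustrative, not part of the commit; the real change lives inside export() in the first hunk):

import torch
import torch.nn.functional as F


def pad_vocab_weights(model_params: dict, tp: int) -> dict:
    # Pad 'tok_embeddings.weight' and 'output.weight' along dim 0 so the
    # vocab dimension becomes divisible by the tensor-parallel size `tp`.
    tok_embeddings = model_params['tok_embeddings.weight']
    vocab_size, _dim = tok_embeddings.shape
    if vocab_size % tp != 0:
        # round vocab_size up to the next multiple of tp
        pad_size = (vocab_size + tp - 1) // tp * tp - vocab_size
        # F.pad pads the last dim first, so (0, 0, 0, pad_size) appends
        # pad_size zero rows at the bottom of dim 0
        model_params['tok_embeddings.weight'] = F.pad(
            tok_embeddings, (0, 0, 0, pad_size), 'constant', 0)
        model_params['output.weight'] = F.pad(
            model_params['output.weight'], (0, 0, 0, pad_size), 'constant', 0)
    return model_params

The extra rows are zeros; the runtime keeps the original vocab_size_ alongside a new vocab_size_padded_, as the C++ hunks below show.
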
@@ -147,11 +147,22 @@ def export(model_name: str,
     attn_bias = False
     inter_size = 0
 
+    tok_embeddings = model_params['tok_embeddings.weight']
+    _vocab_size, dim = tok_embeddings.shape
+    head_num = dim // size_per_head
+    if _vocab_size % tp != 0:
+        # Resolve https://github.com/InternLM/lmdeploy/issues/266
+        # Pad tok_embeddings and output weights, making their shape divisible by TP  # noqa: E501
+        pad_size = (_vocab_size + tp - 1) // tp * tp - _vocab_size
+        # Pad weight at the bottom of dim 0
+        model_params['tok_embeddings.weight'] = torch.nn.functional.pad(
+            tok_embeddings, (0, 0, 0, pad_size), 'constant', 0)
+        # Pad output weight at the bottom of dim 0
+        model_params['output.weight'] = torch.nn.functional.pad(
+            model_params['output.weight'], (0, 0, 0, pad_size), 'constant', 0)
+
     # reverse the splitting axes since the weights are transposed above
     for param_name, param_data in model_params.items():
-        if param_name == 'tok_embeddings.weight':
-            _vocab_size, dim = param_data.shape
-            head_num = dim // size_per_head
         split_dim = None
         key, ext = param_name.split('.')[-2:]
         if key == 'w_qkv' and ext == 'bias':
...

@@ -152,7 +152,7 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
     const size_t batchxbeam   = batch_size;
     const size_t hidden_units = llama_->hidden_units_;
-    const size_t vocab_size   = llama_->vocab_size_;
+    const size_t vocab_size   = llama_->vocab_size_padded_;
 
     context_decoder_input_buf_ =
         (T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
@@ -899,11 +899,11 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_ * max_context_token_num_);
+        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
         const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
-            FT_CHECK(llama_->vocab_size_ % tp == 0);
-            const auto local_vocab_size = llama_->vocab_size_ / tp;
+            FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
+            const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
             local_context_logits_buf_ =
                 (float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
         }
@@ -921,7 +921,7 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
                                         cudaMemcpyDefault,
                                         stream_));
         }
-        logits += llama_->vocab_size_ * lengths[k];
+        logits += llama_->vocab_size_padded_ * lengths[k];
     }
 }
...

@@ -71,6 +71,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
     inter_size_(inter_size),
     num_layer_(num_layer),
     vocab_size_(vocab_size),
+    vocab_size_padded_(vocab_size),
     rmsnorm_eps_(norm_eps),
     start_id_(start_id),
     end_id_(end_id),
@@ -90,9 +91,10 @@ LlamaV2<T>::LlamaV2(size_t head_num,
 {
     TM_LOG_DEBUG(__PRETTY_FUNCTION__);
-    FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0);
     TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
+    vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
 
     size_t elem_bits = 0;
     if (quant_policy & QuantPolicy::kCacheKVInt8) {
         elem_bits = sizeof(int8_t) * 8;
@@ -168,7 +170,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
                                          quant_policy);
 
     dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
-                                                           vocab_size_,  // vocab_size_padded,
+                                                           vocab_size_padded_,
                                                            0,  // end_id, deprecated
                                                            stream_,
                                                            cublas_wrapper_,
@@ -333,8 +335,8 @@ void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T
                               cublasGemmAlgo_t(-1));
     }
     else {
-        FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0);
-        const size_t local_vocab_size = vocab_size_ / tensor_para_.world_size_;
+        FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0);
+        const size_t local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_;
         cublas_wrapper_->Gemm(CUBLAS_OP_T,
                               CUBLAS_OP_N,
                               local_vocab_size,  // n
@@ -389,7 +391,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
     int local_batch_size = (int)batch_size;
     std::unordered_map<std::string, Tensor> dynamic_decode_input_tensors{
-        {"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_}, logits}},
+        {"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_padded_}, logits}},
         {"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
        {"max_input_length", {MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}},
         {"sequence_limit_length", {MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len}},
...

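On the runtime side, the constructor above uses the same ceil-division rounding, vocab_size_padded_ = (vocab_size + tp - 1) / tp * tp. As a purely illustrative example, a vocabulary of 32000 tokens with a tensor-parallel size of 3 is padded to (32000 + 2) / 3 * 3 = 32001, so each rank owns exactly 32001 / 3 = 10667 rows of the embedding and output matrices, and the divisibility FT_CHECK in postDecodeEmbedding holds by construction.
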
@@ -156,6 +156,7 @@ private:
     const size_t inter_size_;
     const size_t num_layer_;
     const size_t vocab_size_;
+    size_t vocab_size_padded_;
     float rmsnorm_eps_ = 1e-6f;
     static constexpr bool neox_rotary_style_ = false;
...

@@ -37,11 +37,16 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
     hidden_units_(head_num * size_per_head),
     inter_size_(inter_size),
     vocab_size_(vocab_size),
+    vocab_size_padded_(vocab_size),
     num_layer_(num_layer),
     weight_type_(weight_type),
     tensor_para_size_(tensor_para_size),
     tensor_para_rank_(tensor_para_rank)
 {
+    if (vocab_size_padded_ % tensor_para_size_ != 0) {
+        vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
+        TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
+    }
     decoder_layer_weights.reserve(num_layer_);
     for (unsigned l = 0; l < num_layer_; ++l) {
         decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(head_num,
@@ -72,9 +77,9 @@ LlamaWeight<T>::~LlamaWeight()
 template<typename T>
 void LlamaWeight<T>::mallocWeights()
 {
-    deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_ * hidden_units_);
+    deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_);
     deviceMalloc((T**)&output_norm_weight, hidden_units_);
-    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_);
+    deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_);
 }
 
 template<typename T>
@@ -84,14 +89,14 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
     dir_path += '/';
 
     loadWeightFromBin((T*)pre_decoder_embedding_table,
-                      {vocab_size_ * hidden_units_},
+                      {vocab_size_padded_ * hidden_units_},
                       dir_path + "tok_embeddings.weight",
                       model_file_type);
     loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
     loadWeightFromBin(
-        (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_}, dir_path + "output.weight", model_file_type);
+        (T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
 
     for (unsigned layer = 0; layer < num_layer_; ++layer) {
         decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
...

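A side note on the weight hunks above, with hypothetical numbers for scale: mallocWeights() and loadModel() now size the embedding table and the output projection with vocab_size_padded_, so each padded row adds hidden_units_ elements per matrix; with hidden_units_ = 4096 and fp16 weights that is 4096 * 2 B = 8 KB per extra row per matrix, negligible next to the full weights. This only works because the export script pads tok_embeddings.weight and output.weight to the same padded size that loadWeightFromBin expects.
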
@@ -58,6 +58,7 @@ private:
     size_t hidden_units_;
     size_t inter_size_;
     size_t vocab_size_;
+    size_t vocab_size_padded_;
     size_t num_layer_;
     WeightType weight_type_;
     size_t tensor_para_size_;
...