Unverified Commit 4903d3cc authored by Lyu Han, committed by GitHub

Pad tok_embedding and output weights to make their shape divisible by TP (#285)

* Pad tok_embedding and output weights to make their shape divisible by TP

* update

* update

* update

* update

* update llamaBatch
parent d5cb0be2
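For readers skimming the diff, here is a minimal sketch of the padding scheme this commit introduces. It assumes PyTorch, and the sizes (vocab_size=32000, dim=4096, tp=3) are illustrative stand-ins, not values taken from the commit; the real change lives in the exporter hunk below and in the C++ files that consume the padded shapes.

    import torch
    import torch.nn.functional as F

    # Illustrative stand-ins for model_params['tok_embeddings.weight'] and model_params['output.weight'].
    vocab_size, dim, tp = 32000, 4096, 3
    tok_embeddings = torch.zeros(vocab_size, dim)
    output_weight = torch.zeros(vocab_size, dim)

    # Round the vocab size up to the next multiple of tp and zero-pad rows at the bottom of dim 0,
    # so each tensor-parallel rank receives an equally sized slice of the vocabulary dimension.
    pad_size = (vocab_size + tp - 1) // tp * tp - vocab_size  # 1 in this example
    tok_embeddings = F.pad(tok_embeddings, (0, 0, 0, pad_size), 'constant', 0)
    output_weight = F.pad(output_weight, (0, 0, 0, pad_size), 'constant', 0)
    assert tok_embeddings.shape[0] % tp == 0  # 32001 rows split evenly across 3 ranks

The padded rows are all zeros. On the runtime side, the decode layer is constructed with both vocab_size_ and vocab_size_padded_ (see the LlamaV2 hunk below), which presumably lets it distinguish the real vocabulary from the padding.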
......@@ -147,11 +147,22 @@ def export(model_name: str,
attn_bias = False
inter_size = 0
tok_embeddings = model_params['tok_embeddings.weight']
_vocab_size, dim = tok_embeddings.shape
head_num = dim // size_per_head
if _vocab_size % tp != 0:
    # Resolve https://github.com/InternLM/lmdeploy/issues/266
    # Pad tok_embeddings and output weights, making their shape divisible by TP # noqa: E501
    pad_size = (_vocab_size + tp - 1) // tp * tp - _vocab_size
    # Pad weight at the bottom of dim 0
    model_params['tok_embeddings.weight'] = torch.nn.functional.pad(
        tok_embeddings, (0, 0, 0, pad_size), 'constant', 0)
    # Pad output weight at the bottom of dim 0
    model_params['output.weight'] = torch.nn.functional.pad(
        model_params['output.weight'], (0, 0, 0, pad_size), 'constant', 0)
# reverse the splitting axes since the weights are transposed above
for param_name, param_data in model_params.items():
    if param_name == 'tok_embeddings.weight':
        _vocab_size, dim = param_data.shape
        head_num = dim // size_per_head
    split_dim = None
    key, ext = param_name.split('.')[-2:]
    if key == 'w_qkv' and ext == 'bias':
......
......@@ -152,7 +152,7 @@ void LlamaBatch<T>::allocateBuffer(size_t batch_size, size_t session_len)
const size_t batchxbeam = batch_size;
const size_t hidden_units = llama_->hidden_units_;
const size_t vocab_size = llama_->vocab_size_;
const size_t vocab_size = llama_->vocab_size_padded_;
context_decoder_input_buf_ =
(T*)allocator_->reMalloc(context_decoder_input_buf_, sizeof(T) * max_context_token_num_ * hidden_units, false);
......@@ -899,11 +899,11 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
if (context_logits_buf_ == nullptr) {
NcclGuard guard(llama_->tensor_para_, stream_, true);
context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_ * max_context_token_num_);
context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
const auto tp = llama_->tensor_para_.world_size_;
if (tp > 1) {
FT_CHECK(llama_->vocab_size_ % tp == 0);
const auto local_vocab_size = llama_->vocab_size_ / tp;
FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
local_context_logits_buf_ =
(float*)allocator_->malloc(sizeof(float) * local_vocab_size * max_context_token_num_);
}
......@@ -921,7 +921,7 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
cudaMemcpyDefault,
stream_));
}
logits += llama_->vocab_size_ * lengths[k];
logits += llama_->vocab_size_padded_ * lengths[k];
}
}
......
......@@ -71,6 +71,7 @@ LlamaV2<T>::LlamaV2(size_t head_num,
inter_size_(inter_size),
num_layer_(num_layer),
vocab_size_(vocab_size),
vocab_size_padded_(vocab_size),
rmsnorm_eps_(norm_eps),
start_id_(start_id),
end_id_(end_id),
......@@ -90,9 +91,10 @@ LlamaV2<T>::LlamaV2(size_t head_num,
{
TM_LOG_DEBUG(__PRETTY_FUNCTION__);
FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0);
TM_LOG_INFO("NCCL group_id = %d", tensor_para_.group_id_);
vocab_size_padded_ = (vocab_size_padded_ + tensor_para_.world_size_ - 1) / tensor_para_.world_size_ * tensor_para_.world_size_;
size_t elem_bits = 0;
if (quant_policy & QuantPolicy::kCacheKVInt8) {
elem_bits = sizeof(int8_t) * 8;
......@@ -168,7 +170,7 @@ void LlamaV2<T>::initialize(const LlamaAttentionParams& attn_params,
quant_policy);
dynamic_decode_layer_ = new DynamicDecodeLayer<float>(vocab_size_,
vocab_size_, // vocab_size_padded,
vocab_size_padded_,
0, // end_id, deprecated
stream_,
cublas_wrapper_,
......@@ -333,8 +335,8 @@ void LlamaV2<T>::postDecodeEmbedding(float* logits, float* local_logits, const T
cublasGemmAlgo_t(-1));
}
else {
FT_CHECK(vocab_size_ % tensor_para_.world_size_ == 0);
const size_t local_vocab_size = vocab_size_ / tensor_para_.world_size_;
FT_CHECK(vocab_size_padded_ % tensor_para_.world_size_ == 0);
const size_t local_vocab_size = vocab_size_padded_ / tensor_para_.world_size_;
cublas_wrapper_->Gemm(CUBLAS_OP_T,
CUBLAS_OP_N,
local_vocab_size, // n
......@@ -389,7 +391,7 @@ void LlamaV2<T>::dynamicDecode(int* token_ids,
int local_batch_size = (int)batch_size;
std::unordered_map<std::string, Tensor> dynamic_decode_input_tensors{
{"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_}, logits}},
{"logits", {MEMORY_GPU, TYPE_FP32, {batch_size, (size_t)1, vocab_size_padded_}, logits}},
{"step", {MEMORY_CPU, TYPE_INT32, {1}, &step}},
{"max_input_length", {MEMORY_CPU, TYPE_INT32, {1}, &max_context_len}},
{"sequence_limit_length", {MEMORY_GPU, TYPE_UINT32, {batch_size}, seq_limit_len}},
......
......@@ -156,6 +156,7 @@ private:
const size_t inter_size_;
const size_t num_layer_;
const size_t vocab_size_;
size_t vocab_size_padded_;
float rmsnorm_eps_ = 1e-6f;
static constexpr bool neox_rotary_style_ = false;
......
......@@ -37,11 +37,16 @@ LlamaWeight<T>::LlamaWeight(size_t head_num,
hidden_units_(head_num * size_per_head),
inter_size_(inter_size),
vocab_size_(vocab_size),
vocab_size_padded_(vocab_size),
num_layer_(num_layer),
weight_type_(weight_type),
tensor_para_size_(tensor_para_size),
tensor_para_rank_(tensor_para_rank)
{
if (vocab_size_padded_ % tensor_para_size_ != 0) {
vocab_size_padded_ = (vocab_size_padded_ + tensor_para_size_ - 1) / tensor_para_size_ * tensor_para_size_;
TM_LOG_WARNING("pad vocab size from %d to %d", vocab_size_, vocab_size_padded_);
}
decoder_layer_weights.reserve(num_layer_);
for (unsigned l = 0; l < num_layer_; ++l) {
decoder_layer_weights.push_back(new LlamaDecoderLayerWeight<T>(head_num,
......@@ -72,9 +77,9 @@ LlamaWeight<T>::~LlamaWeight()
template<typename T>
void LlamaWeight<T>::mallocWeights()
{
deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_ * hidden_units_);
deviceMalloc((T**)&pre_decoder_embedding_table, vocab_size_padded_ * hidden_units_);
deviceMalloc((T**)&output_norm_weight, hidden_units_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_);
deviceMalloc((T**)&post_decoder_embedding_kernel, hidden_units_ * vocab_size_padded_);
}
template<typename T>
......@@ -84,14 +89,14 @@ void LlamaWeight<T>::loadModel(std::string dir_path)
dir_path += '/';
loadWeightFromBin((T*)pre_decoder_embedding_table,
{vocab_size_ * hidden_units_},
{vocab_size_padded_ * hidden_units_},
dir_path + "tok_embeddings.weight",
model_file_type);
loadWeightFromBin((T*)output_norm_weight, {hidden_units_}, dir_path + "norm.weight", model_file_type);
loadWeightFromBin(
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_}, dir_path + "output.weight", model_file_type);
(T*)post_decoder_embedding_kernel, {hidden_units_ * vocab_size_padded_}, dir_path + "output.weight", model_file_type);
for (unsigned layer = 0; layer < num_layer_; ++layer) {
decoder_layer_weights[layer]->loadModel(dir_path + "layers." + std::to_string(layer), model_file_type);
......
......@@ -58,6 +58,7 @@ private:
size_t hidden_units_;
size_t inter_size_;
size_t vocab_size_;
size_t vocab_size_padded_;
size_t num_layer_;
WeightType weight_type_;
size_t tensor_para_size_;
......
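On the C++ side, LlamaV2 and LlamaWeight round the vocabulary up with the same expression, (v + tp - 1) / tp * tp. A small sketch of that arithmetic (pad_vocab is a hypothetical helper name and the sizes are illustrative, not taken from the commit):

    def pad_vocab(vocab_size: int, tp: int) -> int:
        # Smallest multiple of tp that is >= vocab_size: integer ceil-division, then scale back up.
        return (vocab_size + tp - 1) // tp * tp

    assert pad_vocab(32000, 2) == 32000  # already divisible, no padding
    assert pad_vocab(32000, 3) == 32001  # one padded row
    assert pad_vocab(50257, 2) == 50258  # one padded row

Because loadModel now reads tok_embeddings.weight and output.weight with vocab_size_padded_ rows, the exporter change at the top of this diff has to write tensors that are already padded to that size, which is what its padding block does.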