Unverified Commit 2dec28ae authored by Chen Xin, committed by GitHub

Fix memory leak (#415)

parent ec034c15
@@ -45,10 +45,10 @@ void BaseSamplingLayer<T>::allocateBuffer(size_t batch_size, Tensor top_k, Tenso
         reinterpret_cast<bool*>(allocator_->reMalloc(skip_decode_buf_, sizeof(bool) * batch_size, false));
     // host buffers.
-    temperature_ = new float[batch_size];
-    repetition_penalty_ = new float[batch_size];
-    min_lengths_ = new int[batch_size];
-    skip_decode_ = new bool[batch_size];
+    temperature_ = (float*)std::realloc((void*)temperature_, batch_size * sizeof(float));
+    repetition_penalty_ = (float*)std::realloc((void*)repetition_penalty_, batch_size * sizeof(float));
+    min_lengths_ = (int*)std::realloc((void*)min_lengths_, batch_size * sizeof(int));
+    skip_decode_ = (bool*)std::realloc((void*)skip_decode_, batch_size * sizeof(bool));
     is_allocate_buffer_ = true;
 }

@@ -65,10 +65,10 @@ void BaseSamplingLayer<T>::freeBuffer()
     allocator_->free((void**)(&min_lengths_buf_));
     allocator_->free((void**)(&runtime_logits_buf_));
     allocator_->free((void**)(&skip_decode_buf_));
-    delete[] temperature_;
-    delete[] repetition_penalty_;
-    delete[] min_lengths_;
-    delete[] skip_decode_;
+    std::free(temperature_);
+    std::free(repetition_penalty_);
+    std::free(min_lengths_);
+    std::free(skip_decode_);
     is_allocate_buffer_ = false;
 }
 }
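These two hunks replace a `new[]` that ran on every `allocateBuffer` call, leaking the previous arrays whenever the layer was re-sized without an intervening `freeBuffer`, with `std::realloc`; the matching release in `freeBuffer` must then be `std::free` rather than `delete[]`, since mixing allocation families is undefined behavior. Below is a minimal sketch of the idiom, not the repository's code; the struct and member names are illustrative. It relies on the guarantee that `std::realloc(nullptr, n)` behaves like `std::malloc(n)`.

```cpp
#include <cstdlib>

// Illustrative stand-in for the sampling layer's per-batch host buffers.
struct HostBuffers {
    float* temperature_ = nullptr;
    bool*  skip_decode_ = nullptr;

    void allocate(std::size_t batch_size) {
        // realloc(nullptr, n) acts like malloc(n), so the same call handles
        // the first allocation and every later resize without leaking the
        // previous block, which a repeated `new float[batch_size]` would do.
        temperature_ = (float*)std::realloc(temperature_, batch_size * sizeof(float));
        skip_decode_ = (bool*)std::realloc(skip_decode_, batch_size * sizeof(bool));
    }

    void release() {
        std::free(temperature_);  // memory from malloc/realloc must go to free(),
        std::free(skip_decode_);  // never to delete[]
        temperature_ = nullptr;
        skip_decode_ = nullptr;
    }
};

int main() {
    HostBuffers b;
    b.allocate(8);   // first batch
    b.allocate(32);  // larger batch: resized, not leaked
    b.release();
}
```

One standard caveat with the `p = realloc(p, n)` form: if the resize fails, `realloc` returns `nullptr` and the original pointer is lost. The diff accepts that trade-off, which is common when a failed host allocation of this size would be treated as fatal anyway.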
@@ -61,8 +61,8 @@ std::unordered_map<std::string, ft::Tensor> LlamaTritonModelInstance<T>::convert
     const size_t request_batch_size = input_tensors->at("input_ids").shape[0];
     const size_t input_data_len = input_tensors->at("input_ids").shape[1];
-    // freed in forward()
-    h_total_output_lengths_ = reinterpret_cast<uint32_t*>(malloc(request_batch_size * sizeof(uint32_t)));
+    h_total_output_lengths_ =
+        (uint32_t*)std::realloc((void*)h_total_output_lengths_, request_batch_size * sizeof(uint32_t));
     std::unordered_map<std::string, ft::Tensor> ft_input_tensors = std::unordered_map<std::string, ft::Tensor>{
         {"input_ids", as_GPU_tensor(input_tensors->at("input_ids"), d_input_ids_)},
@@ -251,11 +251,6 @@ LlamaTritonModelInstance<T>::forward(std::shared_ptr<std::unordered_map<std::str
         output_tensors.insert({"error_message", ft::Tensor{ft::MEMORY_CPU, ft::TYPE_BYTES, {1}, &h_exception_}});
     }
-    if (h_total_output_lengths_ != nullptr) {
-        free(h_total_output_lengths_);
-        h_total_output_lengths_ = nullptr;
-    }
     return convert_outputs(output_tensors);
 }
@@ -293,6 +288,7 @@ void LlamaTritonModelInstance<T>::freeBuffer()
     allocator_->free((void**)(&d_sequence_lengths_));
     allocator_->free((void**)(&d_output_log_probs_));
     allocator_->free((void**)(&d_cum_log_probs_));
+    std::free(h_total_output_lengths_);
 }

 template struct LlamaTritonModelInstance<float>;
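These last two hunks complete the ownership move: `forward()` no longer frees `h_total_output_lengths_` on every request, and `freeBuffer()` releases it exactly once alongside the device buffers. Since `std::free` on a null pointer is defined to do nothing, the added line is safe even for an instance whose buffer was never allocated.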