Unverified Commit d44a8bfe authored by Chen Xin, committed by GitHub

Reduce gil switching (#407)

* reduce gil switching

* ffi lock func

* remove unused

* remove unused

* remove unused
parent 2dec28ae
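
This commit threads a lock-control callback (`ffi_api_lock_ctrl_t`, a `std::function<void(int)>`) from the pybind11 layer down to the model, so the Python GIL is acquired once around the whole loop of stream callbacks in `finish()` instead of being toggled for every callback. Below is a minimal, self-contained sketch of that pattern, assuming only the CPython C API (`PyGILState_Ensure` / `PyGILState_Release`); everything else in it is illustrative and not the repository's code.

```cpp
// Hypothetical sketch of the GIL-control pattern introduced here: take the GIL
// once around a batch of Python-facing callbacks instead of once per callback.
// Only PyGILState_Ensure/PyGILState_Release are real CPython API; the rest is
// an illustrative stand-in, not LlamaBatch or the pybind11 binding itself.
#include <Python.h>
#include <functional>
#include <vector>

using ffi_api_lock_ctrl_t = std::function<void(int)>;

int main()
{
    Py_Initialize();

    // Same shape as the gil_control lambda in the binding:
    // nonzero op -> acquire the GIL, zero op -> release it.
    ffi_api_lock_ctrl_t ffi_lock = [state = PyGILState_STATE{}](int op) mutable {
        if (op) {
            state = PyGILState_Ensure();
        }
        else {
            PyGILState_Release(state);
        }
    };

    // Stand-ins for the per-request stream callbacks.
    std::vector<std::function<void()>> callbacks = {
        [] { PyRun_SimpleString("print('callback 0')"); },
        [] { PyRun_SimpleString("print('callback 1')"); },
    };

    ffi_lock(1);                  // one GIL acquisition for the whole batch
    for (auto& cb : callbacks) {  // instead of ensure/release per callback
        cb();
    }
    ffi_lock(0);

    Py_FinalizeEx();
    return 0;
}
```

Passing the callback as a plain `std::function<void(int)>` keeps the engine code free of any Python headers; only the pybind11 module knows the lock is really the GIL.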
@@ -899,8 +899,9 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
         const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
             const auto local_vocab_size = llama_->vocab_size_padded_ / tp;
@@ -938,12 +939,18 @@ void LlamaBatch<T>::finish()
     check_cuda_error(cudaStreamSynchronize(stream_));

+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }

     if (debug_ && rank_ == 0) {
         std::stringstream ss;
...
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>

+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {

 template<typename T>
@@ -91,6 +93,11 @@ public:
         return vocab_size_;
     }

+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:

     friend class Batch;
@@ -188,6 +195,8 @@ private:
     std::shared_ptr<SharedState> shared_state_;

     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };

 } // namespace turbomind
@@ -344,13 +344,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t pipeline_para_size,
            int enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
-                    tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                auto model = std::make_shared<LlamaTritonModel<half>>(
+                    tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
-                    tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                auto model = std::make_shared<LlamaTritonModel<float>>(
+                    tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
         },
         "model_dir"_a,
...
@@ -276,6 +276,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
...
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
     void handleMissingParams();

+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int getTensorParaSize() override;
     int getPipelineParaSize() override;
@@ -112,4 +117,6 @@ private:
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };