Unverified commit d44a8bfe authored by Chen Xin, committed by GitHub

Reduce gil switching (#407)

* reduce gil switching

* ffi lock func

* remove unused

* remove unused

* remove unused
parent 2dec28ae
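
In short, the change threads a lock-control callback of type ffi_api_lock_ctrl_t (a std::function<void(int)>, where 1 means acquire and 0 means release) from the pybind11 binding down through LlamaTritonModel into LlamaV2 and LlamaBatch, so the engine can take the Python GIL once around a whole batch of stream callbacks. Below is a minimal sketch of that pattern with a plain std::mutex standing in for the GIL; the Engine type and its finish() shape are illustrative stand-ins, not the actual turbomind classes.

    #include <functional>
    #include <iostream>
    #include <mutex>
    #include <vector>

    using ffi_api_lock_ctrl_t = std::function<void(int)>;  // 1 = acquire, 0 = release

    struct Engine {  // illustrative stand-in for LlamaV2/LlamaBatch
        ffi_api_lock_ctrl_t ffi_lock_ = nullptr;

        void setFfiLock(ffi_api_lock_ctrl_t func) { ffi_lock_ = func; }

        // Mirrors the shape of LlamaBatch<T>::finish(): one acquire,
        // N callbacks, one release.
        void finish(const std::vector<std::function<void()>>& stream_cbs) {
            if (ffi_lock_) ffi_lock_(1);
            for (const auto& cb : stream_cbs) cb();
            if (ffi_lock_) ffi_lock_(0);
        }
    };

    int main() {
        std::mutex gil;  // stand-in for CPython's GIL
        Engine engine;
        engine.setFfiLock([&gil](int op) { op ? gil.lock() : gil.unlock(); });
        engine.finish({[] { std::cout << "token callback 1\n"; },
                       [] { std::cout << "token callback 2\n"; }});
    }
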
@@ -899,7 +899,8 @@ void LlamaBatch<T>::outputContextLogits(T* context_decoder_
     if (context_logits_buf_ == nullptr) {
         NcclGuard guard(llama_->tensor_para_, stream_, true);
-        context_logits_buf_ = (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
+        context_logits_buf_ =
+            (float*)allocator_->malloc(sizeof(float) * llama_->vocab_size_padded_ * max_context_token_num_);
         const auto tp = llama_->tensor_para_.world_size_;
         if (tp > 1) {
             FT_CHECK(llama_->vocab_size_padded_ % tp == 0);
@@ -938,12 +939,18 @@ void LlamaBatch<T>::finish()
     check_cuda_error(cudaStreamSynchronize(stream_));

+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(1);
+    }
     for (int i = 0; i < batch_size_; ++i) {
         FT_CHECK(requests_[i] != nullptr);
         if (requests_[i]->stream_cb && rank_ == 0) {
             requests_[i]->stream_cb(&requests_[i]->outputs[rank_].get());
         }
     }
+    if (rank_ == 0 && llama_->ffi_lock_) {
+        llama_->ffi_lock_(0);
+    }

     if (debug_ && rank_ == 0) {
         std::stringstream ss;
......
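
Net effect of the hunk above: the GIL transition moves out of the per-request loop, so rank 0 performs one acquire/release pair per finish() call rather than one per stream callback. That batching of lock traffic is the "gil switching" the commit title says it reduces.
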
@@ -34,6 +34,8 @@
 #include "src/turbomind/utils/nccl_utils.h"
 #include <unordered_map>

+using ffi_api_lock_ctrl_t = std::function<void(int)>;
+
 namespace turbomind {

 template<typename T>
@@ -91,6 +93,11 @@ public:
         return vocab_size_;
     }

+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
 private:
     friend class Batch;
@@ -188,6 +195,8 @@ private:
     std::shared_ptr<SharedState> shared_state_;

     std::thread internal_thread_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };

 }  // namespace turbomind
@@ -344,13 +344,25 @@ PYBIND11_MODULE(_turbomind, m)
            size_t      pipeline_para_size,
            int         enable_custom_all_reduce,
            std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
+            auto gil_control = [state = PyGILState_STATE{}](int op) mutable {
+                if (op) {
+                    state = PyGILState_Ensure();
+                }
+                else {
+                    PyGILState_Release(state);
+                }
+            };
             if (data_type == "half" || data_type == "fp16" || data_type == "int4") {
-                return std::make_shared<LlamaTritonModel<half>>(
+                auto model = std::make_shared<LlamaTritonModel<half>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
             else {
-                return std::make_shared<LlamaTritonModel<float>>(
+                auto model = std::make_shared<LlamaTritonModel<float>>(
                     tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
+                model->setFfiLock(gil_control);
+                return model;
             }
         },
         "model_dir"_a,
......
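
A note on the gil_control lambda above: it captures a single PyGILState_STATE by value and mutates it, so calls must strictly alternate, acquire then release, and must never nest, because PyGILState_Release consumes the token that the matching PyGILState_Ensure returned. A minimal sketch of that call discipline follows; run_python_callbacks is a hypothetical caller, the real one being LlamaBatch<T>::finish() in the first hunk.

    #include <Python.h>
    #include <functional>

    // Hypothetical consumer of gil_control; illustrates the pairing contract only.
    void run_python_callbacks(const std::function<void(int)>& gil_control) {
        gil_control(1);  // PyGILState_Ensure(): safe even on a thread CPython has never seen
        // ... invoke the Python stream callbacks while the GIL is held ...
        gil_control(0);  // PyGILState_Release(): must pair 1:1 with the Ensure above
    }
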
@@ -276,6 +276,7 @@ LlamaTritonModel<T>::createModelInstance(int
         instance = shared_instances_[device_id].lock();
         if (!instance) {
             instance = createSharedModelInstance(device_id, rank, nccl_params, custom_all_reduce_comm);
+            instance->llm->setFfiLock(ffi_lock_);
             shared_instances_[device_id] = instance;
         }
     }
......
@@ -63,6 +63,11 @@ struct LlamaTritonModel: public AbstractTransformerModel {
     void handleMissingParams();

+    void setFfiLock(ffi_api_lock_ctrl_t func)
+    {
+        ffi_lock_ = func;
+    }
+
     std::string toString() override;
     int         getTensorParaSize() override;
     int         getPipelineParaSize() override;
@@ -112,4 +117,6 @@ private:
     std::string model_name_;
     std::string model_dir_;
+
+    ffi_api_lock_ctrl_t ffi_lock_ = nullptr;
 };