Commit 0a21fff9 authored by xiabo

Adapt to 0.1.0

parent 9484fd1c
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_lifecycle.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "repo_agent.h"
#include "triton/common/logging.h"
#include "triton/common/thread_pool.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {
const std::string&
ModelReadyStateString(ModelReadyState state)
{
switch (state) {
case ModelReadyState::UNKNOWN: {
static std::string m("UNKNOWN");
return m;
}
case ModelReadyState::READY: {
static std::string m("READY");
return m;
}
case ModelReadyState::UNAVAILABLE: {
static std::string m("UNAVAILABLE");
return m;
}
case ModelReadyState::LOADING: {
static std::string m("LOADING");
return m;
}
case ModelReadyState::UNLOADING: {
static std::string m("UNLOADING");
return m;
}
}
static std::string m("<unknown>");
return m;
}
namespace {
Status
VersionsToLoad(
const std::string model_path, const std::string& name,
const inference::ModelConfig& model_config, std::set<int64_t>* versions)
{
versions->clear();
// Get integral number of the version directory
std::set<std::string> subdirs;
RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
std::set<int64_t, std::greater<int64_t>> existing_versions;
for (const auto& subdir : subdirs) {
if (subdir == kWarmupDataFolder || subdir == kInitialStateFolder) {
continue;
}
if ((subdir.length() > 1) && (subdir.front() == '0')) {
LOG_WARNING << "ignore version directory '" << subdir
<< "' which contains leading zeros in its directory name";
continue;
}
try {
int64_t version = std::stoll(subdir);
existing_versions.insert(version);
}
catch (const std::invalid_argument& ia) {
LOG_WARNING << "ignore version directory '" << subdir
<< "' which fails to convert to integral number";
}
}
if (model_config.version_policy().has_specific()) {
for (const auto& v : model_config.version_policy().specific().versions()) {
// Only load the specific versions that are present in the model directory
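// Note: std::set::insert() returns {iterator, inserted}; 'inserted' is true
// only when the version was not already in 'existing_versions', i.e. the
// corresponding version directory is missing on disk.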
bool version_not_exist = existing_versions.insert(v).second;
if (!version_not_exist) {
versions->emplace(v);
} else {
LOG_ERROR << "version " << v << " is specified for model '" << name
<< "', but the version directory is not present";
}
}
} else {
if (model_config.version_policy().has_latest()) {
// std::set is sorted with std::greater
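// so iteration yields the largest (newest) version numbers first and only
// the first 'num_versions' of them are kept.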
for (const auto& v : existing_versions) {
if (versions->size() >=
model_config.version_policy().latest().num_versions()) {
break;
}
versions->emplace(v);
}
} else {
// all
versions->insert(existing_versions.begin(), existing_versions.end());
}
}
return Status::Success;
}
// Use smart pointer with custom deleter so that model state will be updated
// to UNAVAILABLE if all smart pointer copies are out of scope
struct ModelDeleter {
ModelDeleter(std::function<void()> OnDestroyModel)
: OnDestroyModel_(std::move(OnDestroyModel))
{
}
void operator()(Model* model)
{
// The actual model object must be destroyed in a different
// thread. This thread could have a callstack that includes the
// model itself because this deleter could be triggered by
// a request release or response send in the model. The delete below
// invokes the model destructor, which may wait on this same thread...
// so we would deadlock if we didn't use a different thread here.
std::function<void()> destroy_fn = OnDestroyModel_;
std::thread dthd([model, destroy_fn]() {
delete model;
destroy_fn();
});
dthd.detach();
}
// Used to inform the ModelLifeCycle that the model handle is destroyed
std::function<void()> OnDestroyModel_;
};
} // namespace
Status
ModelLifeCycle::Create(
InferenceServer* server, const ModelLifeCycleOptions& options,
std::unique_ptr<ModelLifeCycle>* life_cycle)
{
std::unique_ptr<ModelLifeCycle> local_life_cycle(
new ModelLifeCycle(server, options));
*life_cycle = std::move(local_life_cycle);
return Status::Success;
}
const ModelStateMap
ModelLifeCycle::LiveModelStates(bool strict_readiness)
{
LOG_VERBOSE(2) << "LiveModelStates()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
ModelStateMap live_model_states;
for (auto& model_version : map_) {
bool live = false;
VersionStateMap version_map;
for (auto& version_model : model_version.second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (strict_readiness &&
version_model.second->state_ != ModelReadyState::READY) {
continue;
}
// At least one version is live (ready / loading / unloading)
if ((version_model.second->state_ != ModelReadyState::UNKNOWN) &&
(version_model.second->state_ != ModelReadyState::UNAVAILABLE)) {
live = true;
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
}
if (live) {
live_model_states[model_version.first] = std::move(version_map);
}
}
return live_model_states;
}
Status
ModelLifeCycle::StopAllModels()
{
LOG_VERBOSE(2) << "StopAllModels()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
for (auto& model_version : map_) {
for (auto& version_model : model_version.second) {
if (version_model.second != nullptr) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->model_ != nullptr) {
version_model.second->model_->Stop();
}
}
}
}
return Status::Success;
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelLifeCycle::InflightStatus()
{
LOG_VERBOSE(2) << "InflightStatus()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
std::set<std::tuple<std::string, int64_t, size_t>> inflight_status;
for (auto& model_version : map_) {
for (auto& version_model : model_version.second) {
if (version_model.second != nullptr) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->model_ != nullptr) {
const auto cnt =
version_model.second->model_->InflightInferenceCount();
if (cnt != 0) {
inflight_status.emplace(
model_version.first, version_model.first, cnt);
}
}
}
}
}
return inflight_status;
}
const ModelStateMap
ModelLifeCycle::ModelStates()
{
LOG_VERBOSE(2) << "ModelStates()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
ModelStateMap model_states;
for (auto& model_version : map_) {
VersionStateMap version_map;
for (auto& version_model : model_version.second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
model_states[model_version.first] = std::move(version_map);
}
return model_states;
}
const VersionStateMap
ModelLifeCycle::VersionStates(const std::string& model_name)
{
LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
VersionStateMap version_map;
auto mit = map_.find(model_name);
if (mit != map_.end()) {
for (auto& version_model : mit->second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
}
return version_map;
}
Status
ModelLifeCycle::ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state)
{
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto mit = map_.find(model_name);
if (mit != map_.end()) {
auto vit = mit->second.find(model_version);
if (vit != mit->second.end()) {
std::lock_guard<std::mutex> lock(vit->second->mtx_);
*state = vit->second->state_;
return Status::Success;
}
}
return Status(
Status::Code::NOT_FOUND, "model '" + model_name + "', version " +
std::to_string(model_version) +
" is not found");
}
Status
ModelLifeCycle::GetModel(
const std::string& model_name, const int64_t version,
std::shared_ptr<Model>* model)
{
LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version;
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto mit = map_.find(model_name);
if (mit == map_.end()) {
return Status(Status::Code::NOT_FOUND, "'" + model_name + "' is not found");
}
auto vit = mit->second.find(version);
if (vit == mit->second.end()) {
if (version != -1) {
return Status(
Status::Code::NOT_FOUND, "'" + model_name + "' version " +
std::to_string(version) +
" is not found");
}
// The case where the request is asking for latest version
int64_t latest = -1;
for (auto& version_model : mit->second) {
if (version_model.first > latest) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->state_ == ModelReadyState::READY) {
latest = version_model.first;
// Tedious, but we have to set the handle for whichever version is
// currently the latest to avoid an edge case like the following:
// with "versions : 1 3 2", version 3 is the latest but may be requested
// to unload while the iterator is still examining version 2; holding
// 'model' ensures version 3 stays valid.
*model = version_model.second->model_;
}
}
}
if (latest == -1) {
return Status(
Status::Code::NOT_FOUND,
"'" + model_name + "' has no available versions");
}
} else {
std::lock_guard<std::mutex> lock(vit->second->mtx_);
if (vit->second->state_ == ModelReadyState::READY) {
*model = vit->second->model_;
} else {
return Status(
Status::Code::UNAVAILABLE, "'" + model_name + "' version " +
std::to_string(version) +
" is not at ready state");
}
}
return Status::Success;
}
Status
ModelLifeCycle::AsyncUnload(const std::string& model_name)
{
LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
if (it == map_.end()) {
return Status(
Status::Code::INVALID_ARG, "Model to be unloaded has not been served");
}
// Get the existing agent models and notify the unload action
const uint64_t now_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
for (auto& version : it->second) {
auto& model_info = version.second;
std::lock_guard<std::mutex> lock(model_info->mtx_);
model_info->last_update_ns_ = now_ns;
// Unload the serving model. For a model that is in LOADING state, the
// updated timestamp will be recognized as a newer update on the model
// info and the in-flight load will be aborted.
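// (OnLoadComplete() compares each version's 'last_update_ns_' against the
// timestamp captured when the load started and marks the load as failed if
// a newer operation, such as this unload, has been applied in between.)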
if (model_info->state_ == ModelReadyState::READY) {
if (model_info->agent_model_list_ != nullptr) {
// Only log the error because the model should be unloaded regardless
auto status = model_info->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_UNLOAD);
if (!status.IsOk()) {
LOG_ERROR
<< "Agent model returns error on TRITONREPOAGENT_ACTION_UNLOAD: "
<< status.AsString();
}
}
// unload
model_info->Release();
}
}
return Status::Success;
}
Status
ModelLifeCycle::AsyncLoad(
const std::string& model_name, const std::string& model_path,
const inference::ModelConfig& model_config, const bool is_config_provided,
const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
std::function<void(Status)>&& OnComplete)
{
LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
if (it == map_.end()) {
it = map_.emplace(std::make_pair(model_name, VersionMap())).first;
}
std::set<int64_t> versions;
RETURN_IF_ERROR(
VersionsToLoad(model_path, model_name, model_config, &versions));
if (versions.empty()) {
return Status(
Status::Code::INVALID_ARG,
"at least one version must be available under the version policy of "
"model '" +
model_name + "'");
}
const uint64_t now_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
std::shared_ptr<LoadTracker> load_tracker(
new LoadTracker(versions.size(), now_ns));
for (const auto& version : versions) {
std::unique_ptr<ModelInfo> linfo(
new ModelInfo(model_path, model_config, now_ns));
ModelInfo* model_info = linfo.get();
LOG_INFO << "loading: " << model_name << ":" << version;
model_info->state_ = ModelReadyState::LOADING;
model_info->state_reason_.clear();
model_info->agent_model_list_ = agent_model_list;
auto res = it->second.emplace(
std::make_pair(version, std::unique_ptr<ModelInfo>()));
if (res.second) {
res.first->second = std::move(linfo);
} else {
// There is already a record of this model version. If the version is
// currently being served, the re-load should be performed in the
// background to avoid version downtime. Otherwise, swap in the newly
// loading model info and monitor its state.
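// 'background_models_' is keyed by the raw ModelInfo pointer (cast to
// uintptr_t) so the owning entry can be located again from the OnDestroy
// callback and from OnLoadComplete() once the operation finishes.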
auto& serving_model = res.first->second;
std::lock_guard<std::mutex> lock(serving_model->mtx_);
if (serving_model->state_ == ModelReadyState::READY) {
background_models_[(uintptr_t)model_info] = std::move(linfo);
} else {
// swap the monitoring model info
serving_model.swap(linfo);
// Further check the state: keep the object in 'background_models_' if
// the model is LOADING / UNLOADING, because the model info will be
// accessed by a different thread once that operation completes.
if ((linfo->state_ == ModelReadyState::LOADING) ||
(linfo->state_ == ModelReadyState::UNLOADING)) {
ModelInfo* key = linfo.get();
background_models_[(uintptr_t)key] = std::move(linfo);
}
}
}
// Load model asynchronously via thread pool
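// Each version is loaded as a separate thread-pool task; the shared
// 'load_tracker' lets OnLoadComplete() detect when the last version has
// finished so the final commit / rollback logic runs only once.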
load_pool_->Enqueue([this, model_name, version, model_info, OnComplete,
load_tracker, is_config_provided]() {
CreateModel(model_name, version, model_info, is_config_provided);
OnLoadComplete(model_name, version, model_info, OnComplete, load_tracker);
});
}
return Status::Success;
}
void
ModelLifeCycle::CreateModel(
const std::string& model_name, const int64_t version, ModelInfo* model_info,
const bool is_config_provided)
{
LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version " << version;
const auto& model_config = model_info->model_config_;
// Create model
Status status;
std::unique_ptr<Model> is;
// If 'backend' is specified in the config then use the new triton
// backend.
if (!model_config.backend().empty()) {
std::unique_ptr<TritonModel> model;
status = TritonModel::Create(
server_, model_info->model_path_, cmdline_config_map_, host_policy_map_,
model_name, version, model_config, is_config_provided, &model);
is.reset(model.release());
} else {
#ifdef TRITON_ENABLE_ENSEMBLE
if (model_info->is_ensemble_) {
status = EnsembleModel::Create(
server_, model_info->model_path_, version, model_config,
is_config_provided, min_compute_capability_, &is);
// Complete the label provider with label information from the involved
// models. Must be done here because the involved models may not be
// obtainable from the server, e.g. during server initialization.
if (status.IsOk()) {
std::set<std::string> no_label_outputs;
const auto& label_provider = is->GetLabelProvider();
for (const auto& output : model_config.output()) {
if (label_provider->GetLabel(output.name(), 0).empty()) {
no_label_outputs.emplace(output.name());
}
}
for (const auto& element : model_config.ensemble_scheduling().step()) {
for (const auto& pair : element.output_map()) {
// Found a model that produces one of the missing outputs
if (no_label_outputs.find(pair.second) != no_label_outputs.end()) {
std::shared_ptr<Model> model;
// Safe to obtain model because the ensemble can't be loaded
// until the involved models are ready
GetModel(element.model_name(), element.model_version(), &model);
label_provider->AddLabels(
pair.second,
model->GetLabelProvider()->GetLabels(pair.first));
}
}
}
}
} else
#endif // TRITON_ENABLE_ENSEMBLE
{
status = Status(
Status::Code::INVALID_ARG,
"unknown platform '" + model_config.platform() + "'");
}
}
std::lock_guard<std::mutex> lock(model_info->mtx_);
if (status.IsOk()) {
// [FIXME] better way to manage agent model lifecycle
// Let the deleter also hold a shared pointer copy of the agent model list,
// because the reference in ModelInfo can be cleared before the Model object
// is destroyed, and we want the agent model to remain valid for receiving
// the UNLOAD_COMPLETE signal (see ~TritonRepoAgentModelList for detail)
auto agent_model_list = model_info->agent_model_list_;
model_info->model_.reset(
is.release(), ModelDeleter([this, model_name, version, model_info,
agent_model_list]() mutable {
LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name
<< "' version " << version;
LOG_INFO << "successfully unloaded '" << model_name << "' version "
<< version;
// Update model state as it is fully unloaded
{
std::lock_guard<std::mutex> lock(model_info->mtx_);
model_info->state_ = ModelReadyState::UNAVAILABLE;
model_info->state_reason_ = "unloaded";
}
// Check if the model info is in background, if so, remove from the
// map
std::lock_guard<std::mutex> lk(this->map_mtx_);
auto it = this->background_models_.find((uintptr_t)model_info);
if (it != this->background_models_.end()) {
this->background_models_.erase(it);
}
}));
} else {
LOG_ERROR << "failed to load '" << model_name << "' version " << version
<< ": " << status.AsString();
model_info->state_ = ModelReadyState::UNAVAILABLE;
model_info->state_reason_ = status.AsString();
}
}
void
ModelLifeCycle::OnLoadComplete(
const std::string& model_name, const int64_t version, ModelInfo* model_info,
std::function<void(Status)> OnComplete,
std::shared_ptr<LoadTracker> load_tracker)
{
std::lock_guard<std::mutex> tracker_lock(load_tracker->mtx_);
++load_tracker->completed_version_cnt_;
load_tracker->load_set_[version] = model_info;
// No version will be marked ready until all versions are ready; this
// simplifies the unloading when one version fails to load, as the other
// versions won't have in-flight requests.
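// The task that completes the last version performs the "commit" below:
// it re-checks the timestamps under 'map_mtx_', rolls back all newly
// loaded versions on failure, or marks them READY and unloads the
// previously serving versions on success.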
if (model_info->state_ != ModelReadyState::LOADING) {
load_tracker->load_failed_ = true;
load_tracker->reason_ +=
("version " + std::to_string(version) + " is at " +
ModelReadyStateString(model_info->state_) +
" state: " + model_info->state_reason_ + ";");
}
// Check if all versions are completed and finish the load
if (load_tracker->completed_version_cnt_ ==
load_tracker->affected_version_cnt_) {
// hold 'map_mtx_' as there will be change onto the model info map
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
// Check if the load is the latest foreground action on the model
for (const auto& version_info : it->second) {
if (version_info.second->last_update_ns_ >
load_tracker->last_update_ns_) {
load_tracker->load_failed_ = true;
load_tracker->reason_ =
"Newer operation has been applied to the model lifecycle, current "
"load operation is out-dated.";
break;
}
}
if (load_tracker->load_failed_) {
// Move agent list out of ModelInfo as it needs to be invoked
// after all ModelInfos are reset
std::shared_ptr<TritonRepoAgentModelList> lagent_list;
if (model_info->agent_model_list_) {
lagent_list = std::move(model_info->agent_model_list_);
}
// If any of the versions fails to load, abort the load and unload
// all newly loaded versions
for (auto& loaded : load_tracker->load_set_) {
// Unload directly; the object is being managed either in the foreground
// or the background
std::lock_guard<std::mutex> lock(loaded.second->mtx_);
if (loaded.second->model_ != nullptr) {
loaded.second->Release();
}
}
if (lagent_list) {
auto status =
lagent_list->InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_FAIL);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_FAIL: "
<< status.AsString();
}
}
} else {
// Unload any previous loaded versions that are still available
for (auto& version_info : it->second) {
auto& mi = version_info.second;
std::lock_guard<std::mutex> info_lk(mi->mtx_);
if ((mi->state_ == ModelReadyState::READY) &&
(mi->last_update_ns_ < load_tracker->last_update_ns_)) {
if (mi->agent_model_list_ != nullptr) {
auto status = mi->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_UNLOAD);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_UNLOAD: "
<< status.AsString();
}
}
mi->Release();
}
}
// Mark current versions ready and track info in foreground
for (auto& loaded : load_tracker->load_set_) {
std::lock_guard<std::mutex> curr_info_lk(loaded.second->mtx_);
loaded.second->state_ = ModelReadyState::READY;
loaded.second->state_reason_.clear();
LOG_INFO << "successfully loaded '" << model_name << "' version "
<< loaded.first;
auto bit = background_models_.find((uintptr_t)loaded.second);
// Check if the version model is loaded in background, if so,
// replace and unload the current serving version
if (bit != background_models_.end()) {
auto vit = it->second.find(loaded.first);
// Need to lock the previous model info in case the model is
// loading / unloading; this ensures the model state stays consistent
// even when that load / unload completes.
std::lock_guard<std::mutex> prev_info_lk(vit->second->mtx_);
// swap previous info into local unique pointer
auto linfo = std::move(bit->second);
vit->second.swap(linfo);
background_models_.erase(bit);
// if previous info is under change, put into 'background_models_'
if ((linfo->state_ == ModelReadyState::LOADING) ||
(linfo->state_ == ModelReadyState::UNLOADING)) {
ModelInfo* key = linfo.get();
background_models_[(uintptr_t)key] = std::move(linfo);
}
}
}
if (model_info->agent_model_list_) {
auto status = model_info->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_LOAD_COMPLETE);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_COMPLETE: "
<< status.AsString();
}
}
}
if (OnComplete != nullptr) {
OnComplete(
load_tracker->load_failed_
? Status(Status::Code::INVALID_ARG, load_tracker->reason_)
: Status::Success);
}
}
}
}} // namespace triton::core
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"
namespace triton { namespace core {
struct ModelLifeCycleOptions {
explicit ModelLifeCycleOptions(
const double min_compute_capability,
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map,
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
const unsigned int model_load_thread_count)
: min_compute_capability_(min_compute_capability),
backend_cmdline_config_map_(backend_cmdline_config_map),
host_policy_map_(host_policy_map),
model_load_thread_count_(model_load_thread_count)
{
}
// The minimum supported CUDA compute capability.
const double min_compute_capability_;
// The backend configuration settings specified on the command-line
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
// The host policy setting used when loading models.
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
// Number of the threads to use for concurrently loading models
const unsigned int model_load_thread_count_;
};
/// Readiness status for models.
enum class ModelReadyState {
// The model is in an unknown state. The model is not available for
// inferencing.
UNKNOWN,
// The model is ready and available for inferencing.
READY,
// The model is unavailable, indicating that the model failed to
// load or has been implicitly or explicitly unloaded. The model is
// not available for inferencing.
UNAVAILABLE,
// The model is being loaded by the inference server. The model is
// not available for inferencing.
LOADING,
// The model is being unloaded by the inference server. The model is
// not available for inferencing.
UNLOADING
};
/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);
using VersionStateMap =
std::map<int64_t, std::pair<ModelReadyState, std::string>>;
using ModelStateMap = std::map<std::string, VersionStateMap>;
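// ModelStateMap maps a model name to a VersionStateMap, which in turn maps
// each version number to a (ModelReadyState, state reason) pair.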
// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
public:
TritonRepoAgentModelList()
: last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){};
~TritonRepoAgentModelList()
{
// Using destructor to finish the unload lifecycle without
// explicitly managing the last step in ModelLifecycle.
if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
}
}
Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
{
agent_models_.emplace_back(std::move(agent_model));
return Status::Success;
}
size_t Size() { return agent_models_.size(); }
TritonRepoAgentModel* Back() { return agent_models_.back().get(); }
Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
{
// Special handling for the current model lifecycle implementation,
// the repo agent may be asked to perform UNLOAD action multiple times,
// and the requests after the first should be ignored.
const bool repeated_unload =
(action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
(last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD);
if (repeated_unload) {
return Status::Success;
}
last_action_type_ = action_type;
switch (action_type) {
case TRITONREPOAGENT_ACTION_LOAD:
case TRITONREPOAGENT_ACTION_UNLOAD: {
for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
}
break;
}
case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
case TRITONREPOAGENT_ACTION_LOAD_FAIL:
case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
// reverse order
for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
--one_pass_idx) {
RETURN_IF_ERROR(
agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
}
break;
}
}
return Status::Success;
}
private:
DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
TRITONREPOAGENT_ActionType last_action_type_;
};
class InferenceServer;
class Model;
class ModelLifeCycle {
public:
static Status Create(
InferenceServer* server, const ModelLifeCycleOptions& options,
std::unique_ptr<ModelLifeCycle>* life_cycle);
~ModelLifeCycle()
{
// Explicitly destroy the thread pool first to drain any pending callbacks
// that may modify model lifecycle members
load_pool_.reset();
map_.clear();
}
// Start loading model with specified versions asynchronously.
// All versions that are being served will be unloaded only after
// the load is finished successfully.
Status AsyncLoad(
const std::string& model_name, const std::string& model_path,
const inference::ModelConfig& model_config, const bool is_config_provided,
const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
std::function<void(Status)>&& OnComplete);
// Unload model asynchronously.
Status AsyncUnload(const std::string& model_name);
// Get the specified version of the model. The latest ready version will
// be retrieved if 'version' is -1. Return an error if the specified version
// is not found or is not ready.
Status GetModel(
const std::string& model_name, const int64_t version,
std::shared_ptr<Model>* model);
// Get the ModelStateMap representation of the live models. A model is
// live if at least one of the versions is not unknown nor unavailable.
// If 'strict_readiness' is true, a model is only live if
// at least one of the versions is ready.
const ModelStateMap LiveModelStates(bool strict_readiness = false);
// Get the ModelStateMap representation of the models.
const ModelStateMap ModelStates();
// Get the VersionStateMap representation of the specified model.
const VersionStateMap VersionStates(const std::string& model_name);
// Get the state of a specific model version.
Status ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state);
// Instruct all models to stop accepting new inference requests.
Status StopAllModels();
// Return the number of in-flight inferences, if any; model versions
// that don't have in-flight inferences will not be included.
const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();
private:
struct ModelInfo {
ModelInfo(
const std::string& model_path,
const inference::ModelConfig& model_config,
const uint64_t last_update_ns)
: model_config_(model_config), model_path_(model_path),
#ifdef TRITON_ENABLE_ENSEMBLE
is_ensemble_(model_config.platform() == kEnsemblePlatform),
#else
is_ensemble_(false),
#endif // TRITON_ENABLE_ENSEMBLE
last_update_ns_(last_update_ns), state_(ModelReadyState::UNKNOWN)
{
}
// Release the flyweight held by the ModelInfo object and reflect it as
// 'UNLOADING' in the model state. Note that 'mtx_' must be acquired before
// invoking this function to prevent a possible data race.
void Release()
{
state_ = ModelReadyState::UNLOADING;
state_reason_.clear();
agent_model_list_.reset();
model_.reset();
}
const inference::ModelConfig model_config_;
const std::string model_path_;
const bool is_ensemble_;
std::mutex mtx_;
uint64_t last_update_ns_;
ModelReadyState state_;
std::string state_reason_;
// flyweight
std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
std::shared_ptr<Model> model_;
};
struct LoadTracker {
LoadTracker(
const size_t affected_version_cnt, const uint64_t last_update_ns)
: last_update_ns_(last_update_ns),
affected_version_cnt_(affected_version_cnt), load_failed_(false),
completed_version_cnt_(0)
{
}
const uint64_t last_update_ns_;
const size_t affected_version_cnt_;
std::mutex mtx_;
bool load_failed_;
std::string reason_;
size_t completed_version_cnt_;
std::map<int64_t, ModelInfo*> load_set_;
};
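// A LoadTracker is created per AsyncLoad() call and shared by the
// per-version load tasks; 'completed_version_cnt_' is compared against
// 'affected_version_cnt_' under 'mtx_' to detect when the last version
// has finished loading.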
ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
: server_(server),
min_compute_capability_(options.min_compute_capability_),
cmdline_config_map_(options.backend_cmdline_config_map_),
host_policy_map_(options.host_policy_map_)
{
load_pool_.reset(new triton::common::ThreadPool(
std::max(1u, options.model_load_thread_count_)));
}
void CreateModel(
const std::string& model_name, const int64_t version,
ModelInfo* model_info, const bool is_config_provided);
// Callback function template for model load.
// 'OnComplete' needs to be passed by value for now as there can be
// multiple versions to be loaded and each holds a copy of
// the 'OnComplete' callback.
void OnLoadComplete(
const std::string& model_name, const int64_t version,
ModelInfo* model_info, std::function<void(Status)> OnComplete,
std::shared_ptr<LoadTracker> load_tracker);
// Mutex for 'map_' and 'background_models_'
std::mutex map_mtx_;
using VersionMap = std::map<int64_t, std::unique_ptr<ModelInfo>>;
using ModelMap = std::map<std::string, VersionMap>;
ModelMap map_;
// Models that are being loaded / unloaded in background
std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;
InferenceServer* server_;
const double min_compute_capability_;
const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
// Fixed-size thread pool to load models at specified concurrency
std::unique_ptr<triton::common::ThreadPool> load_pool_;
};
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_repository_manager.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "ensemble_utils.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {
namespace {
static std::string file_prefix = "file:";
// Internal repo agent used for model file override
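// It writes the override 'config' parameter and any 'file:<relative path>'
// parameters into a temporary local directory and then points the model's
// location at that directory.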
class LocalizeRepoAgent : public TritonRepoAgent {
public:
LocalizeRepoAgent()
: TritonRepoAgent("ModelRepositoryManager::LocalizeRepoAgent")
{
// Callbacks below interact with TritonRepoAgentModel directly knowing that
// it is the internal implementation of TRITONREPOAGENT_AgentModel
model_action_fn_ = [](TRITONREPOAGENT_Agent* agent,
TRITONREPOAGENT_AgentModel* model,
const TRITONREPOAGENT_ActionType action_type)
-> TRITONSERVER_Error* {
auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
switch (action_type) {
case TRITONREPOAGENT_ACTION_LOAD: {
// localize the override files for model loading,
// as currently the model is expected to load from local directory
const char* temp_dir_cstr = nullptr;
RETURN_TRITONSERVER_ERROR_IF_ERROR(
agent_model->AcquireMutableLocation(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr));
const std::string temp_dir = temp_dir_cstr;
const auto& files =
*reinterpret_cast<std::vector<const InferenceParameter*>*>(
agent_model->State());
bool found_config = false;
for (const auto& file : files) {
if (file->Name() == "config") {
if (file->Type() != TRITONSERVER_PARAMETER_STRING) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"Config parameter 'config' must have string type for its "
"value");
}
inference::ModelConfig config;
RETURN_TRITONSERVER_ERROR_IF_ERROR(JsonToModelConfig(
file->ValueString(), 1 /* config_version */, &config));
RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteTextProto(
JoinPath({temp_dir, kModelConfigPbTxt}), config));
found_config = true;
} else if (file->Name().rfind(file_prefix, 0) == 0) {
if (file->Type() != TRITONSERVER_PARAMETER_BYTES) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("File parameter '") + file->Name() +
"' must have bytes type for its value")
.c_str());
}
// Save model file to the instructed directory
// mkdir
const std::string file_path =
JoinPath({temp_dir, file->Name().substr(file_prefix.size())});
const std::string dir = DirName(file_path);
bool dir_exist = false;
RETURN_TRITONSERVER_ERROR_IF_ERROR(FileExists(dir, &dir_exist));
if (dir_exist) {
bool is_dir = false;
RETURN_TRITONSERVER_ERROR_IF_ERROR(IsDirectory(dir, &is_dir));
if (!is_dir) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("Invalid file parameter '") + file->Name() +
"', directory has been created as a file")
.c_str());
}
} else {
RETURN_TRITONSERVER_ERROR_IF_ERROR(
MakeDirectory(dir, true /* recursive */));
}
// write
RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteBinaryFile(
file_path,
reinterpret_cast<const char*>(file->ValuePointer()),
file->ValueByteSize()));
}
}
if (!found_config) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"Load parameter 'config' must be specified for model file "
"override");
}
// Commit the temporary directory
RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->SetLocation(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, temp_dir_cstr));
break;
}
default:
break;
}
return nullptr; // success
};
model_fini_fn_ =
[](TRITONREPOAGENT_Agent* agent,
TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* {
auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->DeleteMutableLocation());
return nullptr; // success
};
}
};
Status
CreateAgentModelListWithLoadAction(
const inference::ModelConfig& original_model_config,
const std::string& original_model_path,
std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
{
if (original_model_config.has_model_repository_agents()) {
// Trick to append user-specified repo agents on top of internal ones
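// Agents are chained in listed order: each agent receives the artifact
// location (and, when readable, the model config) produced by the previous
// agent in the list.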
std::shared_ptr<TritonRepoAgentModelList> lagent_model_list;
if (*agent_model_list != nullptr) {
lagent_model_list = std::move(*agent_model_list);
} else {
lagent_model_list.reset(new TritonRepoAgentModelList());
}
FileSystemType filesystem_type;
RETURN_IF_ERROR(GetFileSystemType(original_model_path, &filesystem_type));
TRITONREPOAGENT_ArtifactType artifact_type =
TRITONREPOAGENT_ARTIFACT_FILESYSTEM;
if (filesystem_type != FileSystemType::LOCAL) {
artifact_type = TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM;
}
const char* location = original_model_path.c_str();
inference::ModelConfig model_config = original_model_config;
for (const auto& agent_config :
original_model_config.model_repository_agents().agents()) {
std::shared_ptr<TritonRepoAgent> agent;
RETURN_IF_ERROR(
TritonRepoAgentManager::CreateAgent(agent_config.name(), &agent));
TritonRepoAgent::Parameters agent_params;
for (const auto& parameter : agent_config.parameters()) {
agent_params.emplace_back(parameter.first, parameter.second);
}
std::unique_ptr<TritonRepoAgentModel> agent_model;
if (lagent_model_list->Size() != 0) {
lagent_model_list->Back()->Location(&artifact_type, &location);
const auto config_path = JoinPath({location, kModelConfigPbTxt});
if (!ReadTextProto(config_path, &model_config).IsOk()) {
model_config.Clear();
}
}
RETURN_IF_ERROR(TritonRepoAgentModel::Create(
artifact_type, location, model_config, agent, agent_params,
&agent_model));
RETURN_IF_ERROR(agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
lagent_model_list->AddAgentModel(std::move(agent_model));
}
*agent_model_list = std::move(lagent_model_list);
}
return Status::Success;
}
int64_t
GetModifiedTime(const std::string& path)
{
// If there is an error in any step the fall-back default
// modification time is 0. This means that in error cases 'path'
// will show as not modified. This is the safe fall-back to avoid
// assuming a model is constantly being modified.
bool path_is_dir;
Status status = IsDirectory(path, &path_is_dir);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
// If 'path' is a file, return its mtime. Otherwise, use the modification
// time of the directory as a baseline in case of file deletion.
int64_t mtime = 0;
status = FileModificationTime(path, &mtime);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
if (!path_is_dir) {
return mtime;
}
// 'path' is a directory. Return the most recent mtime of the
// contents of the directory.
std::set<std::string> contents;
status = GetDirectoryContents(path, &contents);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
for (const auto& child : contents) {
const auto full_path = JoinPath({path, child});
mtime = std::max(mtime, GetModifiedTime(full_path));
}
return mtime;
}
// Return true if any file in the subdirectory root at 'path' has been
// modified more recently than 'last'. Return the most-recent modified
// time in 'last'.
bool
IsModified(const std::string& path, int64_t* last_ns)
{
const int64_t repo_ns = GetModifiedTime(path);
bool modified = repo_ns > *last_ns;
*last_ns = repo_ns;
return modified;
}
} // namespace
struct ModelRepositoryManager::ModelInfo {
ModelInfo(
const int64_t mtime_nsec, const int64_t prev_mtime_ns,
const std::string& model_path)
: mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns),
explicitly_load_(true), model_path_(model_path),
is_config_provided_(false)
{
}
ModelInfo()
: mtime_nsec_(0), prev_mtime_ns_(0), explicitly_load_(true),
is_config_provided_(false)
{
}
int64_t mtime_nsec_;
int64_t prev_mtime_ns_;
bool explicitly_load_;
inference::ModelConfig model_config_;
std::string model_path_;
// Temporary location to hold the agent model list before creating the
// model; ownership must be transferred to ModelLifeCycle to ensure the
// agent model life cycle is handled properly.
std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
bool is_config_provided_;
};
ModelRepositoryManager::ModelRepositoryManager(
const std::set<std::string>& repository_paths, const bool autofill,
const bool polling_enabled, const bool model_control_enabled,
const double min_compute_capability,
std::unique_ptr<ModelLifeCycle> life_cycle)
: repository_paths_(repository_paths), autofill_(autofill),
polling_enabled_(polling_enabled),
model_control_enabled_(model_control_enabled),
min_compute_capability_(min_compute_capability),
model_life_cycle_(std::move(life_cycle))
{
}
ModelRepositoryManager::~ModelRepositoryManager() {}
Status
ModelRepositoryManager::Create(
InferenceServer* server, const std::string& server_version,
const std::set<std::string>& repository_paths,
const std::set<std::string>& startup_models, const bool strict_model_config,
const bool polling_enabled, const bool model_control_enabled,
const ModelLifeCycleOptions& life_cycle_options,
std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
{
// The rest only matters if repository path is valid directory
for (const auto& path : repository_paths) {
bool path_is_dir;
RETURN_IF_ERROR(IsDirectory(path, &path_is_dir));
if (!path_is_dir) {
return Status(
Status::Code::INVALID_ARG,
"repository path is not a valid directory");
}
}
if (polling_enabled && model_control_enabled) {
return Status(
Status::Code::INVALID_ARG,
"cannot enable both polling and explicit model control");
}
std::unique_ptr<ModelLifeCycle> life_cycle;
RETURN_IF_ERROR(
ModelLifeCycle::Create(server, life_cycle_options, &life_cycle));
// Not setting the smart pointer directly to simplify clean up
std::unique_ptr<ModelRepositoryManager> local_manager(
new ModelRepositoryManager(
repository_paths, !strict_model_config, polling_enabled,
model_control_enabled, life_cycle_options.min_compute_capability_,
std::move(life_cycle)));
*model_repository_manager = std::move(local_manager);
// Support loading all models on startup in explicit model control mode with
// special startup_model name "*". This does not imply support for pattern
// matching in model names.
bool load_all_models_on_startup = false;
if ((startup_models.find("*") != startup_models.end()) &&
model_control_enabled) {
if (startup_models.size() > 1) {
return Status(
Status::Code::INVALID_ARG,
"Wildcard model name '*' must be the ONLY startup model "
"if specified at all.");
}
load_all_models_on_startup = true;
}
bool all_models_polled = true;
if (!model_control_enabled || load_all_models_on_startup) {
// Only errors that happen before model load / unload will be returned;
// model loading / unloading errors will be printed but ignored.
RETURN_IF_ERROR(
(*model_repository_manager)->PollAndUpdateInternal(&all_models_polled));
} else {
// Load each specified startup_model
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
models;
for (const auto& model_name : startup_models) {
models[model_name];
}
RETURN_IF_ERROR(
(*model_repository_manager)
->LoadUnloadModels(
models, ActionType::LOAD, false, &all_models_polled));
}
if (!all_models_polled) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
// Some models may have failed to load after the model manager is created;
// return a proper error and let the caller decide whether to proceed.
for (const auto& model : (*model_repository_manager)->infos_) {
const auto version_states =
(*model_repository_manager)
->model_life_cycle_->VersionStates(model.first);
// Return a general error message; the detail of each model's loading
// state is logged separately.
if (version_states.empty()) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
for (const auto& state : version_states) {
if (state.second.first != ModelReadyState::READY) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
}
}
return Status::Success;
}
Status
ModelRepositoryManager::PollAndUpdate()
{
if (!polling_enabled_) {
return Status(Status::Code::UNAVAILABLE, "polling is disabled");
}
bool all_models_polled;
return PollAndUpdateInternal(&all_models_polled);
}
Status
ModelRepositoryManager::PollAndUpdateInternal(bool* all_models_polled)
{
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
std::set<std::string> added, deleted, modified, unmodified;
// We don't modify 'infos_' in place to minimize how long we need to
// hold the lock and also to prevent any partial changes due to an error
// during processing.
ModelInfoMap new_infos;
// Each subdirectory of repository path is a model directory from
// which we read the model configuration.
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
subdirs;
RETURN_IF_ERROR(Poll(
subdirs, &added, &deleted, &modified, &unmodified, &new_infos,
all_models_polled));
// Anything in 'infos_' that is not in "added", "modified", or
// "unmodified" is deleted.
for (const auto& pr : infos_) {
if ((added.find(pr.first) == added.end()) &&
(modified.find(pr.first) == modified.end()) &&
(unmodified.find(pr.first) == unmodified.end())) {
deleted.insert(pr.first);
}
}
// Nothing to do if no model adds, deletes or modifies.
if (added.empty() && deleted.empty() && modified.empty()) {
return Status::Success;
}
infos_.swap(new_infos);
UpdateDependencyGraph(added, deleted, modified);
for (const auto& name : deleted) {
model_life_cycle_->AsyncUnload(name);
}
// Model loading / unloading errors will be printed but ignored.
LoadModelByDependency();
return Status::Success;
}
std::map<std::string, Status>
ModelRepositoryManager::LoadModelByDependency()
{
std::map<std::string, Status> res;
struct ModelState {
ModelState(DependencyNode* node) : node_(node), status_(Status::Success) {}
DependencyNode* node_;
Status status_;
std::promise<void> ready_;
};
NodeSet loaded_models;
auto set_pair = ModelsToLoadUnload(loaded_models);
// Loop until all models are loaded / unloaded
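// 'set_pair.first' holds the valid nodes to load and 'set_pair.second' the
// invalid nodes to unload; each iteration feeds the handled nodes back into
// ModelsToLoadUnload() so that models depending on them can be picked up
// next.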
while ((!set_pair.first.empty()) || (!set_pair.second.empty())) {
loaded_models.clear();
// Unload invalid models first
for (auto& invalid_model : set_pair.second) {
model_life_cycle_->AsyncUnload(invalid_model->model_name_);
LOG_ERROR << invalid_model->status_.AsString();
invalid_model->loaded_versions_ = std::set<int64_t>();
loaded_models.emplace(invalid_model);
}
// load valid models and wait for load results
std::vector<std::unique_ptr<ModelState>> model_states;
for (auto& valid_model : set_pair.first) {
model_states.emplace_back(new ModelState(valid_model));
auto model_state = model_states.back().get();
const auto itr = infos_.find(valid_model->model_name_);
auto status = model_life_cycle_->AsyncLoad(
valid_model->model_name_, itr->second->model_path_,
valid_model->model_config_, itr->second->is_config_provided_,
itr->second->agent_model_list_, [model_state](Status load_status) {
model_state->status_ = load_status;
model_state->ready_.set_value();
});
if (!status.IsOk()) {
model_state->status_ = status;
model_state->ready_.set_value();
LOG_ERROR << "failed to load model '" << valid_model->model_name_
<< "': " << status.Message();
}
loaded_models.emplace(valid_model);
}
for (auto& model_state : model_states) {
model_state->ready_.get_future().wait();
res[model_state->node_->model_name_] = model_state->status_;
const auto version_state =
model_life_cycle_->VersionStates(model_state->node_->model_name_);
model_state->node_->loaded_versions_.clear();
for (const auto& vs : version_state) {
if (vs.second.first == ModelReadyState::READY) {
model_state->node_->loaded_versions_.emplace(vs.first);
}
}
// If the model failed to load, revert the timestamp to ensure the next
// load request will attempt to load the model again, for operation
// consistency.
if (!model_state->status_.IsOk()) {
auto& model_info = infos_.find(model_state->node_->model_name_)->second;
model_info->mtime_nsec_ = model_info->prev_mtime_ns_;
}
}
set_pair = ModelsToLoadUnload(loaded_models);
}
// Clear temporarily stored agent model lists after all loads are triggered
for (auto& info : infos_) {
info.second->agent_model_list_.reset();
}
return res;
}
Status
ModelRepositoryManager::LoadUnloadModel(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNAVAILABLE,
"explicit model load / unload is not allowed if polling is enabled");
}
if (models.size() > 1) {
return Status(
Status::Code::UNSUPPORTED,
"explicit load / unload multiple models is not currently supported");
}
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
bool polled = true;
RETURN_IF_ERROR(LoadUnloadModels(models, type, unload_dependents, &polled));
// Check if model is loaded / unloaded properly
const auto& model_name = models.begin()->first;
if (!polled) {
return Status(
Status::Code::INTERNAL, "failed to load '" + model_name +
"', failed to poll from model repository");
}
const auto version_states = model_life_cycle_->VersionStates(model_name);
if (type == ActionType::LOAD) {
if (version_states.empty()) {
return Status(
Status::Code::INTERNAL,
"failed to load '" + model_name + "', no version is available");
}
auto it = infos_.find(model_name);
if (it == infos_.end()) {
return Status(
Status::Code::INTERNAL,
"failed to load '" + model_name +
"', failed to poll from model repository");
}
} else {
std::string ready_version_str;
for (const auto& version_state : version_states) {
if (version_state.second.first == ModelReadyState::READY) {
ready_version_str += std::to_string(version_state.first);
ready_version_str += ",";
}
}
if (!ready_version_str.empty()) {
ready_version_str.pop_back();
return Status(
Status::Code::INTERNAL,
"failed to unload '" + model_name +
"', versions that are still available: " + ready_version_str);
}
}
return Status::Success;
}
Status
ModelRepositoryManager::LoadUnloadModels(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents,
bool* all_models_polled)
{
auto status = Status::Success;
*all_models_polled = true;
// Update ModelInfo related to file system accordingly
std::set<std::string> added, deleted, modified, unmodified;
{
if (type == ActionType::UNLOAD) {
for (const auto& model : models) {
deleted.insert(model.first);
}
}
// ActionType::LOAD and in model control mode
else {
std::set<std::string> checked_models;
auto current_models = models;
for (const auto& model : models) {
checked_models.emplace(model.first);
}
ModelInfoMap new_infos;
#ifdef TRITON_ENABLE_ENSEMBLE
bool first_iteration = true;
#endif // TRITON_ENABLE_ENSEMBLE
while (!current_models.empty()) {
bool polled = true;
RETURN_IF_ERROR(Poll(
current_models, &added, &deleted, &modified, &unmodified,
&new_infos, &polled));
*all_models_polled &= polled;
// More models should be polled if the polled models are ensembles
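// (Breadth-first expansion: ensemble step models found in this iteration
// are queued in 'next_models'; 'checked_models' guards against polling the
// same model more than once.)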
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
next_models;
#ifdef TRITON_ENABLE_ENSEMBLE
for (const auto& model : current_models) {
auto it = new_infos.find(model.first);
// Some models may be marked as deleted and not in 'new_infos'
if (it != new_infos.end()) {
it->second->explicitly_load_ = first_iteration;
const auto& config = it->second->model_config_;
if (config.has_ensemble_scheduling()) {
for (const auto& step : config.ensemble_scheduling().step()) {
bool need_poll =
checked_models.emplace(step.model_name()).second;
if (need_poll) {
next_models[step.model_name()];
}
}
}
}
}
first_iteration = false;
#endif // TRITON_ENABLE_ENSEMBLE
current_models.swap(next_models);
}
// Only update the infos when all validation is completed
for (const auto& model_name : added) {
auto nitr = new_infos.find(model_name);
infos_.emplace(model_name, std::move(nitr->second));
}
for (const auto& model_name : modified) {
auto nitr = new_infos.find(model_name);
auto itr = infos_.find(model_name);
itr->second = std::move(nitr->second);
}
}
}
std::set<std::string> deleted_dependents;
// Update dependency graph and load
UpdateDependencyGraph(
added, deleted, modified,
unload_dependents ? &deleted_dependents : nullptr);
// The models are in 'deleted' either when they are asked to be unloaded or
// when they are not found / are duplicated across all model repositories.
// In all cases, we should unload them and remove them from 'infos_'
// explicitly.
for (const auto& name : (unload_dependents ? deleted_dependents : deleted)) {
infos_.erase(name);
model_life_cycle_->AsyncUnload(name);
}
// load / unload the models affected, and check the load status of
// the requested models
const auto& load_status = LoadModelByDependency();
if (status.IsOk() && (type == ActionType::LOAD)) {
std::string load_error_message = "";
for (const auto& model : models) {
auto it = load_status.find(model.first);
// If 'model.first' is not in the load status, the (re-)load was not
// necessary because there was no change in the model's directory
if ((it != load_status.end()) && !it->second.IsOk()) {
load_error_message +=
("load failed for model '" + model.first +
"': " + it->second.Message() + "\n");
}
}
if (!load_error_message.empty()) {
status = Status(Status::Code::INVALID_ARG, load_error_message);
}
}
return status;
}
Status
ModelRepositoryManager::UnloadAllModels()
{
Status status;
for (const auto& name_info : infos_) {
Status unload_status = model_life_cycle_->AsyncUnload(name_info.first);
if (!unload_status.IsOk()) {
status = Status(
unload_status.ErrorCode(),
"Failed to gracefully unload models: " + unload_status.Message());
}
}
return status;
}
Status
ModelRepositoryManager::StopAllModels()
{
return model_life_cycle_->StopAllModels();
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelRepositoryManager::InflightStatus()
{
return model_life_cycle_->InflightStatus();
}
const ModelStateMap
ModelRepositoryManager::LiveModelStates(bool strict_readiness)
{
return model_life_cycle_->LiveModelStates(strict_readiness);
}
const ModelStateMap
ModelRepositoryManager::ModelStates()
{
return model_life_cycle_->ModelStates();
}
const VersionStateMap
ModelRepositoryManager::VersionStates(const std::string& model_name)
{
return model_life_cycle_->VersionStates(model_name);
}
Status
ModelRepositoryManager::ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state)
{
return model_life_cycle_->ModelState(model_name, model_version, state);
}
Status
ModelRepositoryManager::RepositoryIndex(
const bool ready_only, std::vector<ModelIndex>* index)
{
std::set<std::string> seen_models;
std::set<std::string> duplicate_models;
for (const auto& repository_path : repository_paths_) {
// For any mapped models in this repository, save the mapping
// from their subdirectory name to model name.
std::map<std::string, std::string> models_in_repo;
for (const auto& mapping_it : model_mappings_) {
if (mapping_it.second.first == repository_path) {
models_in_repo.emplace(
BaseName(mapping_it.second.second), mapping_it.first);
}
}
std::set<std::string> subdirs;
RETURN_IF_ERROR(GetDirectorySubdirs(repository_path, &subdirs));
for (const auto& subdir : subdirs) {
auto model = subdir;
auto model_it = models_in_repo.find(subdir);
if (model_it != models_in_repo.end()) {
model = model_it->second;
}
if (seen_models.find(model) != seen_models.end()) {
duplicate_models.insert(model);
}
seen_models.insert(model);
}
}
ModelStateMap states = ModelStates();
for (const auto& model : seen_models) {
// If the same model appears in multiple repositories then show it
// as unavailable since duplicate models are not allowed to load.
if (duplicate_models.find(model) != duplicate_models.end()) {
index->emplace_back(
model, -1 /* version */, ModelReadyState::UNAVAILABLE,
MODEL_READY_REASON_DUPLICATE);
continue;
}
// If there is any version/state/reason associated with the model
// then include that in the index.
auto sitr = states.find(model);
if (sitr == states.end()) {
if (!ready_only) {
index->emplace_back(model);
}
} else {
for (const auto& pr : sitr->second) {
if (!ready_only || (pr.second.first == ModelReadyState::READY)) {
index->emplace_back(
model, pr.first, pr.second.first, pr.second.second);
}
}
}
}
return Status::Success;
}
Status
ModelRepositoryManager::GetModel(
const std::string& model_name, const int64_t model_version,
std::shared_ptr<Model>* model)
{
Status status = model_life_cycle_->GetModel(model_name, model_version, model);
if (!status.IsOk()) {
model->reset();
status = Status(
status.ErrorCode(), "Request for unknown model: " + status.Message());
}
return status;
}
Status
ModelRepositoryManager::Poll(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
std::set<std::string>* added, std::set<std::string>* deleted,
std::set<std::string>* modified, std::set<std::string>* unmodified,
ModelInfoMap* updated_infos, bool* all_models_polled)
{
*all_models_polled = true;
// empty path is the special case to indicate the model should be loaded
// from override file content in 'models'.
std::map<std::string, std::string> model_to_path;
// If no model is specified, poll all models in all model repositories.
// Otherwise, only poll the specified models
if (models.empty()) {
std::set<std::string> duplicated_models;
for (const auto& repository_path : repository_paths_) {
std::set<std::string> subdirs;
Status status = GetDirectorySubdirs(repository_path, &subdirs);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll model repository '" << repository_path
<< "': " << status.Message();
*all_models_polled = false;
} else {
for (const auto& subdir : subdirs) {
if (!model_to_path
.emplace(subdir, JoinPath({repository_path, subdir}))
.second) {
duplicated_models.insert(subdir);
*all_models_polled = false;
}
}
}
}
// If the model is not unique, mark as deleted to unload it
for (const auto& model : duplicated_models) {
model_to_path.erase(model);
deleted->insert(model);
LOG_ERROR << "failed to poll model '" << model
<< "': not unique across all model repositories";
}
}
// If models are specified, this is explicit model control mode.
else {
for (const auto& model : models) {
// Skip repository polling if override model files
if (ModelDirectoryOverride(model.second)) {
model_to_path.emplace(model.first, "");
continue;
}
// Check model mapping first to see if matching model to load.
bool exists = false;
auto model_it = model_mappings_.find(model.first);
if (model_it != model_mappings_.end()) {
bool exists_in_this_repo = false;
auto full_path = model_it->second.second;
Status status = FileExists(full_path, &exists_in_this_repo);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll mapped path '" << full_path
<< "' for model '" << model.first
<< "': " << status.Message();
*all_models_polled = false;
}
if (exists_in_this_repo) {
model_to_path.emplace(model.first, model_it->second.second);
exists = true;
} else {
LOG_ERROR << "mapped path '" << full_path
<< "' does not exist for model '" << model.first << "'";
exists = false;
}
} else {
for (const auto& repository_path : repository_paths_) {
bool exists_in_this_repo = false;
const auto full_path = JoinPath({repository_path, model.first});
Status status = FileExists(full_path, &exists_in_this_repo);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll model repository '" << repository_path
<< "' for model '" << model.first
<< "': " << status.Message();
*all_models_polled = false;
} else if (exists_in_this_repo) {
// Check to make sure this directory is not mapped.
// If mapped, continue to next repository path.
bool mapped = false;
for (auto const& mapping : model_mappings_) {
if (mapping.second.second == full_path) {
mapped = true;
break;
}
}
if (mapped) {
continue;
}
auto res = model_to_path.emplace(
model.first, JoinPath({repository_path, model.first}));
if (res.second) {
exists = true;
} else {
exists = false;
model_to_path.erase(res.first);
LOG_ERROR << "failed to poll model '" << model.first
<< "': not unique across all model repositories";
break;
}
}
}
}
// For an explicitly specified model that doesn't exist, we don't mark it
// as deleted, we simply mark that we couldn't poll all models.
if (!exists) {
*all_models_polled = false;
}
}
}
// Poll each of the models. If an error happens while polling a model,
// its state falls back to the state before the poll.
for (const auto& pair : model_to_path) {
std::unique_ptr<ModelInfo> model_info;
const auto& mit = models.find(pair.first);
static std::vector<const InferenceParameter*> empty_params;
auto status = InitializeModelInfo(
pair.first, pair.second,
((mit == models.end()) ? empty_params : mit->second), &model_info);
const auto& iitr = infos_.find(pair.first);
const bool invalid_add = (!status.IsOk()) && (iitr == infos_.end());
if (!invalid_add) {
const auto& ret = updated_infos->emplace(pair.first, nullptr);
if (!ret.second) {
return Status(
Status::Code::ALREADY_EXISTS,
"unexpected model info for model '" + pair.first + "'");
}
// Classify load state and set updated info
if (model_info == nullptr) {
ret.first->second.reset(new ModelInfo(*iitr->second));
unmodified->insert(pair.first);
} else {
ret.first->second = std::move(model_info);
if (iitr != infos_.end()) {
modified->insert(pair.first);
} else {
added->insert(pair.first);
}
}
}
if (!status.IsOk()) {
LOG_ERROR << "Poll failed for model directory '" << pair.first
<< "': " << status.Message();
*all_models_polled = false;
}
}
return Status::Success;
}
bool
ModelRepositoryManager::ModelDirectoryOverride(
const std::vector<const InferenceParameter*>& model_params)
{
for (const auto& param : model_params) {
if (param->Name().rfind(file_prefix, 0) == 0) {
// param name starts with prefix if user provides override file
return true;
}
}
return false;
}
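// A sketch of the parameter naming convention relied on above. The exact
// value of 'file_prefix' is defined elsewhere in this file; "file:" is
// assumed here for illustration only.
//
//   "file:1/model.dat" -> override file content; repository polling is
//                         skipped and the model loads from the provided data
//   "config"           -> config override, handled in InitializeModelInfo()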
Status
ModelRepositoryManager::InitializeModelInfo(
const std::string& name, const std::string& path,
const std::vector<const InferenceParameter*>& params,
std::unique_ptr<ModelInfo>* info)
{
std::unique_ptr<ModelInfo> linfo(new ModelInfo());
linfo->model_path_ = path;
bool unmodified = false;
const auto iitr = infos_.find(name);
// Set 'prev_mtime_ns_' if there is existing ModelInfo
if (iitr != infos_.end()) {
linfo->prev_mtime_ns_ = iitr->second->mtime_nsec_;
} else {
linfo->prev_mtime_ns_ = 0;
}
// Set 'mtime_nsec_' and override 'model_path_' if current path is empty
// (file override is specified)
if (linfo->model_path_.empty()) {
// Need to localize the override files, use repo agent to manage
// the lifecycle of the localized files
std::shared_ptr<TritonRepoAgent> localize_agent(new LocalizeRepoAgent());
std::unique_ptr<TritonRepoAgentModel> localize_agent_model;
RETURN_IF_ERROR(TritonRepoAgentModel::Create(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, "", inference::ModelConfig(),
localize_agent, {}, &localize_agent_model));
// Set agent model state so the repo agent can access the encoded files
// Using const_cast here but we are safe as the RepoAgent will not
// modify the state
localize_agent_model->SetState(
const_cast<void*>(reinterpret_cast<const void*>(&params)));
RETURN_IF_ERROR(
localize_agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
const char* location;
TRITONREPOAGENT_ArtifactType type;
RETURN_IF_ERROR(localize_agent_model->Location(&type, &location));
// For a file override, set 'mtime_nsec_' to the minimum value so that
// the next load without an override triggers a re-load to undo
// the override, even though the local files may be unchanged.
linfo->mtime_nsec_ = 0;
linfo->model_path_ = location;
linfo->agent_model_list_.reset(new TritonRepoAgentModelList());
linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model));
} else {
if (iitr == infos_.end()) {
linfo->mtime_nsec_ = GetModifiedTime(std::string(linfo->model_path_));
} else {
// Check the current timestamps to determine if model actually has been
// modified
linfo->mtime_nsec_ = linfo->prev_mtime_ns_;
unmodified =
!IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_);
}
}
// Set 'model_config_'
bool parsed_config = false;
// Check if there is config override
for (const auto& override_parameter : params) {
if ((override_parameter->Name() == "config") &&
(override_parameter->Type() == TRITONSERVER_PARAMETER_STRING)) {
// When an override happens, set 'mtime_nsec_' to the minimum value so that
// the next load without an override triggers a re-load to undo
// the override, even though the local files may be unchanged.
linfo->mtime_nsec_ = 0;
unmodified = false;
const std::string& override_config = override_parameter->ValueString();
auto err = JsonToModelConfig(
override_config, 1 /* config_version */, &linfo->model_config_);
if (!err.IsOk()) {
return Status(
Status::Code::INVALID_ARG,
"Invalid config override: " + std::string(err.Message()));
}
parsed_config = true;
break;
} else if (override_parameter->Name().rfind(file_prefix, 0) != 0) {
return Status(
Status::Code::INVALID_ARG,
"Unrecognized load parameter '" + override_parameter->Name() +
"' with type '" +
TRITONSERVER_ParameterTypeString(override_parameter->Type()) +
"'");
}
}
// The polled model is considered unmodified at this point and can be
// returned with 'info' == nullptr
if (unmodified) {
return Status::Success;
}
// Create the associated repo agent models when a model is to be loaded,
// this must be done before normalizing model config as agents might
// redirect to use the model config at a different location
if (!parsed_config) {
const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt});
bool model_config_exists = false;
RETURN_IF_ERROR(FileExists(config_path, &model_config_exists));
// model config can be missing if auto fill is set
if (autofill_ && !model_config_exists) {
linfo->model_config_.Clear();
} else {
RETURN_IF_ERROR(ReadTextProto(config_path, &linfo->model_config_));
parsed_config = true;
}
}
if (parsed_config) {
RETURN_IF_ERROR(CreateAgentModelListWithLoadAction(
linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_));
if (linfo->agent_model_list_ != nullptr) {
// Get the latest repository path
const char* location;
TRITONREPOAGENT_ArtifactType artifact_type;
RETURN_IF_ERROR(linfo->agent_model_list_->Back()->Location(
&artifact_type, &location));
auto latest_path = std::string(location);
linfo->model_path_ = latest_path;
}
}
linfo->is_config_provided_ = parsed_config;
// Try to automatically generate missing parts of the model
// configuration (autofill) that don't require model detail
RETURN_IF_ERROR(GetNormalizedModelConfig(
name, linfo->model_path_, min_compute_capability_,
&linfo->model_config_));
// Note that the model inputs and outputs are not validated until
// the model is initialized, as they may not be auto-completed
// until then.
RETURN_IF_ERROR(
ValidateModelConfig(linfo->model_config_, min_compute_capability_));
if (!autofill_) {
RETURN_IF_ERROR(ValidateModelIOConfig(linfo->model_config_));
}
// If the model is mapped, update its config name based on the
// mapping.
if (model_mappings_.find(name) != model_mappings_.end()) {
linfo->model_config_.set_name(name);
} else {
// If there is no model mapping, make sure the name of the model
// matches the name of the directory. This is a somewhat arbitrary
// requirement but seems like good practice to require it of the user.
// It also acts as a check to make sure we don't have two different
// models with the same name.
if (linfo->model_config_.name() != name) {
return Status(
Status::Code::INVALID_ARG,
"unexpected directory name '" + name + "' for model '" +
linfo->model_config_.name() +
"', directory name must equal model name");
}
}
*info = std::move(linfo);
return Status::Success;
}
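// A minimal sketch of a config-override load parameter handled above (the
// JSON is an illustrative rendering of the ModelConfig message, not a schema
// taken from this file):
//
//   name  : "config"
//   type  : TRITONSERVER_PARAMETER_STRING
//   value : {"name": "mymodel", "backend": "onnxruntime", "max_batch_size": 8}
//
// The override is parsed with JsonToModelConfig(..., 1 /* config_version */, ...)
// and forces 'mtime_nsec_' to 0 so that the next load without the override
// triggers a re-load that undoes it.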
Status
ModelRepositoryManager::UpdateDependencyGraph(
const std::set<std::string>& added, const std::set<std::string>& deleted,
const std::set<std::string>& modified,
std::set<std::string>* deleted_dependents)
{
// Update the dependency graph. If the state of a node changes, all of its
// downstream nodes are affected.
// Deleted nodes are dropped from 'dependency_graph_' and added to
// 'missing_nodes_' if their downstreams are not empty. The affected nodes
// are all ensembles, as only ensembles depend on other models.
std::set<DependencyNode*> affected_nodes;
std::set<DependencyNode*> updated_nodes;
std::set<std::string> current_deleted = deleted;
while (!current_deleted.empty()) {
std::set<std::string> next_deleted;
for (const auto& model_name : current_deleted) {
auto it = dependency_graph_.find(model_name);
if (it != dependency_graph_.end()) {
// remove this node from its upstreams
for (auto& upstream : it->second->upstreams_) {
upstream.first->downstreams_.erase(it->second.get());
// Check if the upstream should be removed as well
if ((deleted_dependents != nullptr) &&
(upstream.first->downstreams_.empty()) &&
(!upstream.first->explicitly_load_)) {
next_deleted.emplace(upstream.first->model_name_);
}
}
it->second->upstreams_.clear();
if (!it->second->downstreams_.empty()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
// mark this node as missing upstream in its downstreams
for (auto& downstream : it->second->downstreams_) {
downstream->missing_upstreams_.emplace(it->second.get());
}
missing_nodes_.emplace(
std::make_pair(model_name, std::move(it->second)));
}
// Make sure deleted node will not be in affected nodes
affected_nodes.erase(it->second.get());
dependency_graph_.erase(it);
}
if (deleted_dependents != nullptr) {
deleted_dependents->emplace(model_name);
}
}
current_deleted.swap(next_deleted);
}
// modified, invalidate (uncheck) all downstreams
for (const auto& model_name : modified) {
auto it = dependency_graph_.find(model_name);
if (it != dependency_graph_.end()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
ModelInfo* info = nullptr;
GetModelInfo(model_name, &info);
it->second->model_config_ = info->model_config_;
it->second->explicitly_load_ = info->explicitly_load_;
// remove this node from its upstream node
for (auto& upstream : it->second->upstreams_) {
upstream.first->downstreams_.erase(it->second.get());
}
it->second->upstreams_.clear();
it->second->checked_ = false;
it->second->status_ = Status::Success;
updated_nodes.emplace(it->second.get());
}
}
// added, add to dependency_graph, if in missing_node, invalidate (uncheck)
// and associate all downstreams, remove from missing_node
for (const auto& model_name : added) {
std::unique_ptr<DependencyNode> added_node;
auto it = missing_nodes_.find(model_name);
if (it != missing_nodes_.end()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
// remove this node from missing upstream node in its downstream nodes
for (auto& downstream : it->second->downstreams_) {
downstream->missing_upstreams_.erase(it->second.get());
}
it->second->checked_ = false;
added_node = std::move(it->second);
missing_nodes_.erase(it);
} else {
// Right now, nothing is going to be filled until validation
added_node.reset(new DependencyNode(model_name));
}
ModelInfo* info = nullptr;
GetModelInfo(model_name, &info);
added_node->model_config_ = info->model_config_;
added_node->explicitly_load_ = info->explicitly_load_;
updated_nodes.emplace(added_node.get());
dependency_graph_.emplace(
std::make_pair(model_name, std::move(added_node)));
}
auto& affected_ensembles = affected_nodes;
for (auto& updated_node : updated_nodes) {
bool is_ensemble = ConnectDependencyGraph(updated_node);
if (is_ensemble) {
affected_ensembles.emplace(updated_node);
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// After the dependency graph is updated, check ensemble dependencies
for (auto& ensemble : affected_ensembles) {
if (ensemble->status_.IsOk()) {
if (!ensemble->missing_upstreams_.empty()) {
std::string name_list;
for (auto it = ensemble->missing_upstreams_.begin();
it != ensemble->missing_upstreams_.end(); it++) {
if (it != ensemble->missing_upstreams_.begin()) {
name_list += ", ";
}
name_list += (*it)->model_name_;
}
ensemble->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble " + ensemble->model_name_ +
" contains models that are not available: " + name_list);
} else {
ensemble->status_ = CircularcyCheck(ensemble, ensemble);
}
}
}
#endif // TRITON_ENABLE_ENSEMBLE
return Status::Success;
}
Status
ModelRepositoryManager::RegisterModelRepository(
const std::string& repository,
const std::unordered_map<std::string, std::string>& model_mapping)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNSUPPORTED,
"repository registration is not allowed if model control mode is not "
"EXPLICIT");
}
bool is_directory = false;
auto status = IsDirectory(repository, &is_directory);
if (!status.IsOk() || !is_directory) {
return Status(
Status::Code::INVALID_ARG, (std::string("failed to register '") +
repository + "', repository not found")
.c_str());
}
{
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
// Check that the repository and mapped models do not already exist.
if (repository_paths_.find(repository) != repository_paths_.end()) {
return Status(
Status::Code::ALREADY_EXISTS,
"model repository '" + repository + "' has already been registered");
}
for (const auto& mapping : model_mapping) {
if (model_mappings_.find(mapping.first) != model_mappings_.end()) {
return Status(
Status::Code::ALREADY_EXISTS,
(std::string("failed to register '") + mapping.first +
"', there is a conflicting mapping for '" +
std::string(mapping.first) + "'")
.c_str());
}
}
repository_paths_.emplace(repository);
for (const auto& mapping : model_mapping) {
model_mappings_.emplace(
mapping.first,
std::make_pair(repository, JoinPath({repository, mapping.second})));
}
}
LOG_INFO << "Model repository registered: " << repository;
return Status::Success;
}
Status
ModelRepositoryManager::UnregisterModelRepository(const std::string& repository)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNSUPPORTED,
"repository unregistration is not allowed if model control mode is not "
"EXPLICIT");
}
{
std::lock_guard<std::mutex> lock(poll_mu_);
if (repository_paths_.erase(repository) != 1) {
return Status(
Status::Code::INVALID_ARG,
"failed to unregister '" + repository + "', repository not found");
}
std::set<std::string> models_to_delete;
for (auto const& mapping : model_mappings_) {
if (mapping.second.first == repository) {
models_to_delete.insert(mapping.first);
}
}
for (auto const& model : models_to_delete) {
model_mappings_.erase(model);
}
}
LOG_INFO << "Model repository unregistered: " << repository;
return Status::Success;
}
Status
ModelRepositoryManager::CircularcyCheck(
DependencyNode* current_node, const DependencyNode* start_node)
{
for (auto& downstream : current_node->downstreams_) {
if (downstream->model_name_ == start_node->model_name_) {
return Status(
Status::Code::INVALID_ARG,
"circular dependency between ensembles: " + start_node->model_name_ +
" -> ... -> " + current_node->model_name_ + " -> " +
start_node->model_name_);
} else {
const auto status = CircularcyCheck(downstream, start_node);
if (!status.IsOk() && current_node->status_.IsOk()) {
current_node->status_ = status;
return status;
}
}
}
return Status::Success;
}
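// Example of the cycle this walk detects (model names are illustrative): if
// ensemble "A" has a step that uses "B" and ensemble "B" has a step that uses
// "A", then A->downstreams_ contains B and B->downstreams_ contains A, so
// CircularcyCheck(A, A) reaches "A" again and reports
//   "circular dependency between ensembles: A -> ... -> B -> A".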
void
ModelRepositoryManager::UncheckDownstream(
NodeSet* downstreams, NodeSet* updated_nodes)
{
// Mark downstream nodes as unchecked recursively
for (auto& node : *downstreams) {
if (node->checked_) {
node->checked_ = false;
node->status_ = Status::Success;
UncheckDownstream(&node->downstreams_, updated_nodes);
updated_nodes->emplace(node);
}
}
}
bool
ModelRepositoryManager::ConnectDependencyGraph(DependencyNode* updated_node)
{
// Check the node's model config to determine if it depends on other models
// and if those models are present
updated_node->upstreams_.clear();
updated_node->missing_upstreams_.clear();
if (updated_node->model_config_.has_ensemble_scheduling()) {
for (const auto& step :
updated_node->model_config_.ensemble_scheduling().step()) {
DependencyNode* upstream_node = nullptr;
const auto& model_name = step.model_name();
auto dit = dependency_graph_.find(model_name);
if (dit == dependency_graph_.end()) {
auto mit = missing_nodes_.find(model_name);
if (mit == missing_nodes_.end()) {
std::unique_ptr<DependencyNode> node(new DependencyNode(model_name));
updated_node->missing_upstreams_.emplace(node.get());
mit = missing_nodes_.emplace(model_name, std::move(node)).first;
}
// Add the node to missing node's downstream so that when the missing
// node is added, the downstreams can be found easily.
mit->second->downstreams_.emplace(updated_node);
upstream_node = mit->second.get();
} else {
dit->second->downstreams_.emplace(updated_node);
upstream_node = dit->second.get();
}
auto res = updated_node->upstreams_.emplace(
upstream_node, std::set<int64_t>({step.model_version()}));
// If the map insertion didn't take place, the same model is required in a
// different step; insert the version into the existing required version set.
if (!res.second) {
res.first->second.insert(step.model_version());
}
}
return true;
}
return false;
}
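// A sketch of the edges created above for an ensemble "E" with the following
// steps (the config excerpt and names are illustrative):
//
//   step { model_name: "A" model_version: -1 }
//   step { model_name: "A" model_version: 2 }
//   step { model_name: "B" model_version: 1 }
//
//   E->upstreams_   : { A -> {-1, 2}, B -> {1} }   (versions merged per model)
//   A->downstreams_ : { E },  B->downstreams_ : { E }
//
// If "A" is not yet in 'dependency_graph_', a placeholder node is created in
// 'missing_nodes_' and also recorded in E->missing_upstreams_.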
Status
ModelRepositoryManager::GetModelInfo(
const std::string& name, ModelInfo** model_info)
{
const auto itr = infos_.find(name);
if (itr == infos_.end()) {
return Status(
Status::Code::NOT_FOUND, "no configuration for model '" + name + "'");
}
*model_info = itr->second.get();
return Status::Success;
}
std::pair<ModelRepositoryManager::NodeSet, ModelRepositoryManager::NodeSet>
ModelRepositoryManager::ModelsToLoadUnload(const NodeSet& loaded_models)
{
// <valid model set, invalid model set>
std::pair<NodeSet, NodeSet> res;
// first call to this function
if (loaded_models.empty()) {
for (auto& pair : dependency_graph_) {
auto node = pair.second.get();
// only care about nodes that are affected by the update
if (!node->checked_) {
if (CheckNode(node)) {
if (node->status_.IsOk()) {
res.first.emplace(node);
} else {
res.second.emplace(node);
}
}
}
}
} else {
for (const auto& model : loaded_models) {
for (auto node : model->downstreams_) {
// only care about nodes that are affected by the update
if (!node->checked_) {
if (CheckNode(node)) {
if (node->status_.IsOk()) {
res.first.emplace(node);
} else {
res.second.emplace(node);
}
}
}
}
}
}
for (auto& node : res.first) {
node->checked_ = true;
}
for (auto& node : res.second) {
node->checked_ = true;
}
return res;
}
bool
ModelRepositoryManager::CheckNode(DependencyNode* node)
{
bool node_ready = true;
// if the node is in invalid status, mark as ready as we know
// it should not be loaded
if (node->status_.IsOk()) {
for (auto& upstream : node->upstreams_) {
if (!upstream.first->checked_) {
node_ready = false;
break;
}
if (!upstream.first->status_.IsOk()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' which is not valid");
} else if (upstream.first->loaded_versions_.empty()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' which has no loaded version");
} else {
for (const auto& required_version : upstream.second) {
if (required_version == -1) {
continue;
}
auto it = upstream.first->loaded_versions_.find(required_version);
if (it == upstream.first->loaded_versions_.end()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' whose required version " +
std::to_string(required_version) + " is not loaded");
}
}
}
if (!node->status_.IsOk()) {
break;
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// Validate ensemble config if the node is ready. By this point, the
// depending models are loaded and their configs are completed
if (node_ready && node->status_.IsOk()) {
node->status_ = ValidateEnsembleConfig(this, node);
}
#endif // TRITON_ENABLE_ENSEMBLE
}
return node_ready;
}
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include <set>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "model_lifecycle.h"
#include "status.h"
#include "triton/common/model_config.h"
namespace triton { namespace core {
class InferenceServer;
class Model;
// [FIXME] should have separated load / unload functions for clarity
enum ActionType { NO_ACTION, LOAD, UNLOAD };
/// Predefined reason strings
#define MODEL_READY_REASON_DUPLICATE "model appears in two or more repositories"
/// An object to manage the model repository active in the server.
class ModelRepositoryManager {
public:
// Index information for a model.
struct ModelIndex {
ModelIndex(const std::string& n)
: name_only_(true), name_(n), version_(-1),
state_(ModelReadyState::UNKNOWN)
{
}
ModelIndex(
const std::string& n, const int64_t v, const ModelReadyState s,
const std::string& r)
: name_only_(false), name_(n), version_(v), state_(s), reason_(r)
{
}
const bool name_only_;
const std::string name_;
const int64_t version_;
const ModelReadyState state_;
const std::string reason_;
};
/// A basic unit in dependency graph that records the models seen by the model
/// repository manager.
struct DependencyNode {
DependencyNode(const std::string& model_name)
: model_name_(model_name), status_(Status::Success), checked_(false)
{
}
std::string model_name_;
Status status_;
bool checked_;
bool explicitly_load_;
inference::ModelConfig model_config_;
std::set<int64_t> loaded_versions_;
std::set<DependencyNode*> missing_upstreams_;
std::unordered_map<DependencyNode*, std::set<int64_t>> upstreams_;
std::set<DependencyNode*> downstreams_;
};
~ModelRepositoryManager();
/// Create a manager for a repository.
/// \param server The pointer to the inference server.
/// \param server_version The version of the inference server.
/// \param repository_paths A set of file-system paths of the repositories.
/// \param startup_models A set of models to be loaded at startup
/// if model control is enabled.
/// \param strict_model_config If false attempt to autofill missing required
/// information in each model configuration.
/// \param polling_enabled If true, then PollAndUpdate() is allowed.
/// Otherwise, it is not allowed.
/// \param model_control_enabled If true, then LoadUnloadModel() is allowed
/// and the models in the model repository will not be loaded at startup.
/// Otherwise, LoadUnloadModel() is not allowed and the models will be loaded.
/// Cannot be set to true if polling_enabled is true.
/// \param life_cycle_options The options to configure ModelLifeCycle.
/// \param model_repository_manager Return the model repository manager.
/// \return The error status.
static Status Create(
InferenceServer* server, const std::string& server_version,
const std::set<std::string>& repository_paths,
const std::set<std::string>& startup_models,
const bool strict_model_config, const bool polling_enabled,
const bool model_control_enabled,
const ModelLifeCycleOptions& life_cycle_options,
std::unique_ptr<ModelRepositoryManager>* model_repository_manager);
/// Poll the model repository to determine the new set of models, compare it
/// with the current set, and serve the new set of models based on their
/// version policy.
Status PollAndUpdate();
/// Load or unload a specified model.
/// \param models The models and the parameters to be loaded or unloaded
/// \param type The type of action to be performed. If the action is LOAD and
/// the model has been loaded, the model will be re-loaded.
/// \return error status. Return "NOT_FOUND" if it tries to load
/// a non-existing model or if it tries to unload a model that hasn't been
/// loaded.
Status LoadUnloadModel(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents);
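// Usage sketch (illustrative only; 'manager' and the model name are
// hypothetical, and parameter construction is omitted):
//
//   std::unordered_map<std::string, std::vector<const InferenceParameter*>>
//       models{{"mymodel", {}}};
//   // (Re-)load "mymodel" from the registered repositories.
//   manager->LoadUnloadModel(models, ActionType::LOAD, false);
//   // Unload it together with models that were only loaded to satisfy its
//   // ensemble dependencies.
//   manager->LoadUnloadModel(models, ActionType::UNLOAD, true);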
/// Unload all models. This function should be called before shutting down
/// the model repository manager.
/// \return error status.
Status UnloadAllModels();
/// Instruct all models to stop accepting new inference requests. However,
/// the models will still process inference requests that they consider
/// to be part of the in-flight inferences.
/// \return error status.
Status StopAllModels();
/// \return the number of in-flight inferences for all versions of all
/// models. The set element will be a tuple of <model_name, model_version,
/// in-flight inference count>. Note that a model version will not be included
/// if it doesn't have in-flight inferences.
const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();
/// \param strict_readiness If true, only models that have at least one
/// ready version will be considered as live. Otherwise, the models that
/// have loading / unloading versions will also be live.
/// \return the state of all versions of all live models.
const ModelStateMap LiveModelStates(bool strict_readiness = false);
/// \return the state of all versions of all models that have ever
/// been loaded (or attempted to be loaded) over the lifetime of the server.
const ModelStateMap ModelStates();
/// \return the states of all versions of a specific model.
const VersionStateMap VersionStates(const std::string& model_name);
/// \return the ready-state of a specific model version.
Status ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state);
/// Get the index of all models in all repositories.
/// \param ready_only If true return only index of models that are ready.
/// \param index Returns the index.
/// \return error status.
Status RepositoryIndex(const bool ready_only, std::vector<ModelIndex>* index);
/// Obtain the specified model.
/// \param model_name The name of the model.
/// \param model_version The version of the model.
/// \param model Return the model object.
/// \return error status.
Status GetModel(
const std::string& model_name, const int64_t model_version,
std::shared_ptr<Model>* model);
/// Register model repository path.
/// \param repository Path to model repository.
/// \param model_mapping Mapping with (overridden) model name as key, subdir
/// name as value.
/// \return error status
Status RegisterModelRepository(
const std::string& repository,
const std::unordered_map<std::string, std::string>& model_mapping);
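// Usage sketch (the repository path and names are illustrative): registering
// "/opt/extra_repo" with a mapping of {"model_x" -> "subdir_1"} serves the
// contents of "/opt/extra_repo/subdir_1" under the model name "model_x".
//
//   manager->RegisterModelRepository(
//       "/opt/extra_repo", {{"model_x", "subdir_1"}});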
/// Unregister model repository path.
/// \param repository Path to model repository.
/// \return error status
Status UnregisterModelRepository(const std::string& repository);
private:
struct ModelInfo;
// Map from model name to information about the model.
using ModelInfoMap =
std::unordered_map<std::string, std::unique_ptr<ModelInfo>>;
// Set of DependencyNode
using NodeSet = std::set<DependencyNode*>;
ModelRepositoryManager(
const std::set<std::string>& repository_paths, const bool autofill,
const bool polling_enabled, const bool model_control_enabled,
const double min_compute_capability,
std::unique_ptr<ModelLifeCycle> life_cycle);
/// The internal function that is called in Create() and PollAndUpdate().
Status PollAndUpdateInternal(bool* all_models_polled);
/// The internal function that loads or unloads a set of models.
Status LoadUnloadModels(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents,
bool* all_models_polled);
/// Poll the requested models in the model repository and
/// compare with the current set. Return the additions, deletions,
/// and modifications that have occurred. This function will not update
/// the current model info; it is the caller's responsibility to do so.
/// \param models The map from models to be polled to their associated
/// parameters.
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param unmodified The names of the models remaining in the
/// repository that have not changed.
/// \param updated_infos The model infos retrieved from the poll.
/// \param all_models_polled Return true if all models are polled and
/// their model configurations are validated successfully. Instead of aborting
/// the polling, the models that fail will be ignored and their model infos
/// will stay in the previous state.
/// \return The error status.
Status Poll(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
std::set<std::string>* added, std::set<std::string>* deleted,
std::set<std::string>* modified, std::set<std::string>* unmodified,
ModelInfoMap* updated_infos, bool* all_models_polled);
/// Helper function for Poll() to initialize ModelInfo for the model.
/// \param name The name of the model.
/// \param path The model path. Empty path means the model is provided via
/// 'params'
/// \param params The model parameters provided for polling model.
/// \param info Return the updated ModelInfo. 'nullptr' will be returned if
/// existing ModelInfo for the model should be reused.
/// \return The error status.
Status InitializeModelInfo(
const std::string& name, const std::string& path,
const std::vector<const InferenceParameter*>& params,
std::unique_ptr<ModelInfo>* info);
/// Load models based on the dependency graph. The function iteratively
/// loads models whose dependencies have all been loaded, and unloads
/// models whose dependencies are no longer satisfied.
/// \return The status of the model loads.
std::map<std::string, Status> LoadModelByDependency();
/// Helper function to update the dependency graph based on the poll result
/// \param added The names of the models added to the repository.
/// \param deleted The names of the models removed from the repository.
/// \param modified The names of the models remaining in the
/// repository that have been changed.
/// \param deleted_dependents The names of dependent models to be removed
/// from the repository.
/// \return The error status.
Status UpdateDependencyGraph(
const std::set<std::string>& added, const std::set<std::string>& deleted,
const std::set<std::string>& modified,
std::set<std::string>* deleted_dependents = nullptr);
/// Helper function to uncheck the nodes because a model that they depend
/// on has changed. The unchecked nodes will be validated again.
/// The function is called recursively to uncheck all downstreams.
/// \param downstreams The nodes to be unchecked.
/// \param updated_nodes Return the nodes that have been unchecked.
void UncheckDownstream(NodeSet* downstreams, NodeSet* updated_nodes);
/// Helper function to construct the edges between nodes in dependency graph.
/// \param updated_node The node that is newly added or modified.
/// \return True if the node represents an ensemble model. False otherwise.
bool ConnectDependencyGraph(DependencyNode* updated_node);
/// Get the model info for a named model.
/// \param name The model name.
/// \param model_info Returns the model information.
/// \return OK if found, NOT_FOUND otherwise.
Status GetModelInfo(const std::string& name, ModelInfo** model_info);
/// Get the models to be loaded / unloaded based on the models loaded in the
/// previous iteration.
/// \param loaded_models The models loaded / unloaded in the previous iteration.
/// Unloaded models will be represented as models with no loaded versions.
/// \return A pair of node set containing models to be loaded and models to be
/// unloaded for the next iteration.
std::pair<NodeSet, NodeSet> ModelsToLoadUnload(const NodeSet& loaded_models);
/// Check if the node is ready for the next iteration. A node is ready if the
/// node is invalid (containing an invalid model config or its dependencies failed
/// to load) or all of its dependencies are satisfied.
/// \param node The node to be checked.
/// \return True if the node is ready. False otherwise.
bool CheckNode(DependencyNode* node);
Status CircularcyCheck(
DependencyNode* current_node, const DependencyNode* start_node);
bool ModelDirectoryOverride(
const std::vector<const InferenceParameter*>& model_params);
std::set<std::string> repository_paths_;
const bool autofill_;
const bool polling_enabled_;
const bool model_control_enabled_;
const double min_compute_capability_;
std::mutex poll_mu_;
ModelInfoMap infos_;
std::unordered_map<std::string, std::unique_ptr<DependencyNode>>
dependency_graph_;
std::unordered_map<std::string, std::unique_ptr<DependencyNode>>
missing_nodes_;
// Mappings from (overridden) model names to a pair of their repository and
// absolute path
std::unordered_map<std::string, std::pair<std::string, std::string>>
model_mappings_;
std::unique_ptr<ModelLifeCycle> model_life_cycle_;
};
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "numa_utils.h"
#ifndef _WIN32
#include <numa.h>
#include <numaif.h>
#endif
#include "triton/common/logging.h"
namespace triton { namespace core {
namespace {
std::string
VectorToString(const std::vector<int>& vec)
{
std::string str("[");
for (const auto& element : vec) {
str += std::to_string(element);
str += ",";
}
str += "]";
return str;
}
Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
try {
*value = std::stoi(arg);
}
catch (const std::invalid_argument& ia) {
return Status(
Status::Code::INVALID_ARG,
msg + ": Can't parse '" + arg + "' to integer");
}
return Status::Success;
}
} // namespace
// NUMA setting will be ignored on Windows platform
#ifdef _WIN32
Status
SetNumaConfigOnThread(
const triton::common::HostPolicyCmdlineConfig& host_policy)
{
return Status::Success;
}
Status
SetNumaMemoryPolicy(const triton::common::HostPolicyCmdlineConfig& host_policy)
{
return Status::Success;
}
Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
*node_mask = 0;
return Status::Success;
}
Status
ResetNumaMemoryPolicy()
{
return Status::Success;
}
Status
SetNumaThreadAffinity(
std::thread::native_handle_type thread,
const triton::common::HostPolicyCmdlineConfig& host_policy)
{
return Status::Success;
}
#else
// Use a flag to make sure no NUMA-related function is actually called
// if Triton is not running with NUMA awareness, since extra Docker
// permissions are needed to call the NUMA functions; this ensures
// backward compatibility.
thread_local bool numa_set = false;
Status
SetNumaConfigOnThread(
const triton::common::HostPolicyCmdlineConfig& host_policy)
{
// Set thread affinity
RETURN_IF_ERROR(SetNumaThreadAffinity(pthread_self(), host_policy));
// Set memory policy
RETURN_IF_ERROR(SetNumaMemoryPolicy(host_policy));
return Status::Success;
}
Status
SetNumaMemoryPolicy(const triton::common::HostPolicyCmdlineConfig& host_policy)
{
const auto it = host_policy.find("numa-node");
if (it != host_policy.end()) {
int node_id;
RETURN_IF_ERROR(
ParseIntOption("Parsing 'numa-node' value", it->second, &node_id));
LOG_VERBOSE(1) << "Thread is binding to NUMA node " << it->second
<< ". Max NUMA node count: " << (numa_max_node() + 1);
numa_set = true;
unsigned long node_mask = 1UL << node_id;
if (set_mempolicy(MPOL_BIND, &node_mask, (numa_max_node() + 1) + 1) != 0) {
return Status(
Status::Code::INTERNAL,
std::string("Unable to set NUMA memory policy: ") + strerror(errno));
}
}
return Status::Success;
}
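// Example (illustrative values): a host policy of {"numa-node": "0"} binds
// memory allocations made by the calling thread to NUMA node 0, i.e.
//
//   node_mask = 1UL << 0;                       // 0b1
//   set_mempolicy(MPOL_BIND, &node_mask, ...);  // as done above
//
// A value that cannot be parsed as an integer is rejected by ParseIntOption().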
Status
GetNumaMemoryPolicyNodeMask(unsigned long* node_mask)
{
*node_mask = 0;
int mode;
if (numa_set &&
get_mempolicy(&mode, node_mask, numa_max_node() + 1, NULL, 0) != 0) {
return Status(
Status::Code::INTERNAL,
std::string("Unable to get NUMA node for current thread: ") +
strerror(errno));
}
return Status::Success;
}
Status
ResetNumaMemoryPolicy()
{
if (numa_set && (set_mempolicy(MPOL_DEFAULT, nullptr, 0) != 0)) {
return Status(
Status::Code::INTERNAL,
std::string("Unable to reset NUMA memory policy: ") + strerror(errno));
}
numa_set = false;
return Status::Success;
}
Status
SetNumaThreadAffinity(
std::thread::native_handle_type thread,
const triton::common::HostPolicyCmdlineConfig& host_policy)
{
const auto it = host_policy.find("cpu-cores");
if (it != host_policy.end()) {
// Parse CPUs
std::vector<int> cpus;
{
const auto& cpu_str = it->second;
auto delim_cpus = cpu_str.find(",");
int current_pos = 0;
while (true) {
auto delim_range = cpu_str.find("-", current_pos);
if (delim_range == std::string::npos) {
return Status(
Status::Code::INVALID_ARG,
std::string("host policy setting 'cpu-cores' format is "
"'<lower_cpu_core_id>-<upper_cpu_core_id>'. Got ") +
cpu_str.substr(
current_pos, ((delim_cpus == std::string::npos)
? (cpu_str.length() + 1)
: delim_cpus) -
current_pos));
}
int lower, upper;
RETURN_IF_ERROR(ParseIntOption(
"Parsing 'cpu-cores' value",
cpu_str.substr(current_pos, delim_range - current_pos), &lower));
RETURN_IF_ERROR(ParseIntOption(
"Parsing 'cpu-cores' value",
(delim_cpus == std::string::npos)
? cpu_str.substr(delim_range + 1)
: cpu_str.substr(
delim_range + 1, delim_cpus - (delim_range + 1)),
&upper));
for (; lower <= upper; ++lower) {
cpus.push_back(lower);
}
// Continue with the next comma-separated range if there is one,
// otherwise break out of the loop
if (delim_cpus != std::string::npos) {
current_pos = delim_cpus + 1;
delim_cpus = cpu_str.find(",", current_pos);
} else {
break;
}
}
}
LOG_VERBOSE(1) << "Thread is binding to one of the CPUs: "
<< VectorToString(cpus);
numa_set = true;
cpu_set_t cpuset;
CPU_ZERO(&cpuset);
for (int cpu : cpus) {
CPU_SET(cpu, &cpuset);
}
if (pthread_setaffinity_np(thread, sizeof(cpu_set_t), &cpuset) != 0) {
return Status(
Status::Code::INTERNAL,
std::string("Unable to set NUMA thread affinity: ") +
strerror(errno));
}
}
return Status::Success;
}
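// Example (illustrative values): a host policy of {"cpu-cores": "0-3,8-11"}
// is parsed above into cpus = {0, 1, 2, 3, 8, 9, 10, 11}; each CPU is added
// to 'cpuset' and the affinity is applied with pthread_setaffinity_np().
// Per the error message above, every comma-separated segment is expected to
// be a "<lower_cpu_core_id>-<upper_cpu_core_id>" range.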
#endif
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <map>
#include <thread>
#include <vector>
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
// Helper function to set the memory policy and thread affinity on the current thread
Status SetNumaConfigOnThread(
const triton::common::HostPolicyCmdlineConfig& host_policy);
// Restrict memory allocation to a specific NUMA node.
Status SetNumaMemoryPolicy(
const triton::common::HostPolicyCmdlineConfig& host_policy);
// Retrieve the node mask used to set memory policy for the current thread
Status GetNumaMemoryPolicyNodeMask(unsigned long* node_mask);
// Reset the memory allocation setting.
Status ResetNumaMemoryPolicy();
// Set the thread affinity to run on specific CPUs.
Status SetNumaThreadAffinity(
std::thread::native_handle_type thread,
const triton::common::HostPolicyCmdlineConfig& host_policy);
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "payload.h"
namespace triton { namespace core {
Payload::Payload()
: op_type_(Operation::INFER_RUN),
requests_(std::vector<std::unique_ptr<InferenceRequest>>()),
OnCallback_([]() {}), instance_(nullptr), state_(State::UNINITIALIZED),
batcher_start_ns_(0), saturated_(false)
{
exec_mu_.reset(new std::mutex());
}
const Status&
Payload::MergePayload(std::shared_ptr<Payload>& payload)
{
if ((payload->GetOpType() != Operation::INFER_RUN) ||
(op_type_ != Operation::INFER_RUN)) {
static Status op_type_error(
Status::Code::INTERNAL,
"Attempted to merge payloads of type that are not INFER_RUN");
return op_type_error;
}
if (payload->GetInstance() != instance_) {
static Status instance_error(
Status::Code::INTERNAL,
"Attempted to merge payloads of mismatching instance");
return instance_error;
}
if ((payload->GetState() != State::EXECUTING) ||
(state_ != State::EXECUTING)) {
static Status state_error(
Status::Code::INTERNAL,
"Attempted to merge payloads that are not in executing state");
return state_error;
}
// Skip the comparison if the required inputs have not been initialized;
// assume that either all payloads are initialized or none are.
if (required_equal_inputs_.Initialized() &&
!required_equal_inputs_.HasEqualInputs(*payload->Requests().begin())) {
static Status shape_error(
Status::Code::INVALID_ARG,
"Attempted to merge payloads that has non-equal inputs");
return shape_error;
}
requests_.insert(
requests_.end(), std::make_move_iterator(payload->Requests().begin()),
std::make_move_iterator(payload->Requests().end()));
payload->Callback();
return Status::Success;
}
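// A sketch of the preconditions checked above: two payloads can be merged
// only when both are INFER_RUN, were assigned the same TritonModelInstance,
// are in the EXECUTING state, and (once 'required_equal_inputs_' is
// initialized) carry requests with equal inputs. On success the source
// payload's requests are moved into this payload and its callback is invoked.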
void
Payload::Reset(const Operation op_type, TritonModelInstance* instance)
{
op_type_ = op_type;
requests_.clear();
OnCallback_ = []() {};
release_callbacks_.clear();
instance_ = instance;
state_ = State::UNINITIALIZED;
status_.reset(new std::promise<Status>());
required_equal_inputs_ = RequiredEqualInputs();
batcher_start_ns_ = 0;
saturated_ = false;
}
void
Payload::Release()
{
op_type_ = Operation::INFER_RUN;
requests_.clear();
OnCallback_ = []() {};
release_callbacks_.clear();
instance_ = nullptr;
state_ = State::RELEASED;
required_equal_inputs_ = RequiredEqualInputs();
batcher_start_ns_ = 0;
saturated_ = false;
}
size_t
Payload::BatchSize()
{
size_t batch_size = 0;
for (const auto& request : requests_) {
batch_size += std::max(1U, request->BatchSize());
}
return batch_size;
}
void
Payload::ReserveRequests(size_t size)
{
requests_.reserve(size);
}
void
Payload::AddRequest(std::unique_ptr<InferenceRequest> request)
{
if ((batcher_start_ns_ == 0) ||
(batcher_start_ns_ > request->BatcherStartNs())) {
batcher_start_ns_ = request->BatcherStartNs();
}
requests_.push_back(std::move(request));
}
void
Payload::SetCallback(std::function<void()> OnCallback)
{
OnCallback_ = OnCallback;
}
void
Payload::SetInstance(TritonModelInstance* model_instance)
{
instance_ = model_instance;
}
void
Payload::AddInternalReleaseCallback(std::function<void()>&& callback)
{
release_callbacks_.emplace_back(std::move(callback));
}
void
Payload::MarkSaturated()
{
saturated_ = true;
}
void
Payload::SetState(Payload::State state)
{
state_ = state;
}
Status
Payload::Wait()
{
return status_->get_future().get();
}
void
Payload::Callback()
{
OnCallback_();
}
void
Payload::OnRelease()
{
// Invoke the internally added release callbacks before releasing the
// request to the user-provided callback.
for (auto it = release_callbacks_.rbegin(); it != release_callbacks_.rend();
it++) {
(*it)();
}
release_callbacks_.clear();
}
void
Payload::Execute(bool* should_exit)
{
*should_exit = false;
Status status;
switch (op_type_) {
case Operation::INFER_RUN:
instance_->Schedule(std::move(requests_), OnCallback_);
break;
case Operation::INIT:
status = instance_->Initialize();
break;
case Operation::WARM_UP:
status = instance_->WarmUp();
break;
case Operation::EXIT:
*should_exit = true;
}
status_->set_value(status);
}
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model_instance.h"
#include "infer_request.h"
#include "scheduler_utils.h"
#include "status.h"
namespace triton { namespace core {
class Payload {
public:
enum Operation { INFER_RUN = 0, INIT = 1, WARM_UP = 2, EXIT = 3 };
enum State {
UNINITIALIZED = 0,
READY = 1,
REQUESTED = 2,
SCHEDULED = 3,
EXECUTING = 4,
RELEASED = 5
};
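// A payload is typically set to REQUESTED when it is enqueued with the
// rate limiter and to SCHEDULED once it is placed on an instance queue
// (see RateLimiter::EnqueuePayload and RateLimiter::SchedulePayload).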
Payload();
void Reset(const Operation op_type, TritonModelInstance* instance = nullptr);
const Status& MergePayload(std::shared_ptr<Payload>& payload);
Operation GetOpType() { return op_type_; }
std::mutex* GetExecMutex() { return exec_mu_.get(); }
size_t RequestCount() { return requests_.size(); }
size_t BatchSize();
void ReserveRequests(size_t size);
void AddRequest(std::unique_ptr<InferenceRequest> request);
std::vector<std::unique_ptr<InferenceRequest>>& Requests()
{
return requests_;
}
uint64_t BatcherStartNs() { return batcher_start_ns_; }
void SetCallback(std::function<void()> OnCallback);
void Callback();
void AddInternalReleaseCallback(std::function<void()>&& callback);
void OnRelease();
void SetInstance(TritonModelInstance* model_instance);
TritonModelInstance* GetInstance() { return instance_; }
void MarkSaturated();
bool IsSaturated() { return saturated_; }
RequiredEqualInputs* MutableRequiredEqualInputs()
{
return &required_equal_inputs_;
}
State GetState() { return state_; }
void SetState(State state);
void Execute(bool* should_exit);
Status Wait();
void Release();
private:
Operation op_type_;
std::vector<std::unique_ptr<InferenceRequest>> requests_;
std::function<void()> OnCallback_;
std::vector<std::function<void()>> release_callbacks_;
TritonModelInstance* instance_;
State state_;
std::unique_ptr<std::promise<Status>> status_;
std::unique_ptr<std::mutex> exec_mu_;
uint64_t batcher_start_ns_;
RequiredEqualInputs required_equal_inputs_;
bool saturated_;
};
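// A rough, non-authoritative sketch of how a payload is typically driven
// through its lifecycle (the surrounding scheduler and threading are
// assumptions; only the interfaces declared above and in rate_limiter.h
// are used):
//
//   auto payload = rate_limiter->GetPayload(Payload::Operation::INFER_RUN);
//   payload->ReserveRequests(requests.size());
//   for (auto& request : requests) {
//     payload->AddRequest(std::move(request));
//   }
//   payload->SetCallback([]() { /* wake the batcher thread */ });
//   rate_limiter->EnqueuePayload(model, payload);
//
//   // Later, on an instance thread, after the payload is dequeued:
//   bool should_exit = false;
//   payload->Execute(&should_exit);
//   payload->Wait();
//   rate_limiter->PayloadRelease(payload);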
}} // namespace triton::core
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "pinned_memory_manager.h"
#include <sstream>
#include "numa_utils.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {
namespace {
std::string
PointerToString(void* ptr)
{
std::stringstream ss;
ss << ptr;
return ss.str();
}
Status
ParseIntOption(const std::string& msg, const std::string& arg, int* value)
{
try {
*value = std::stoi(arg);
}
catch (const std::invalid_argument& ia) {
return Status(
Status::Code::INVALID_ARG,
msg + ": Can't parse '" + arg + "' to integer");
}
return Status::Success;
}
} // namespace
std::unique_ptr<PinnedMemoryManager> PinnedMemoryManager::instance_;
uint64_t PinnedMemoryManager::pinned_memory_byte_size_;
PinnedMemoryManager::PinnedMemory::PinnedMemory(
void* pinned_memory_buffer, uint64_t size)
: pinned_memory_buffer_(pinned_memory_buffer)
{
if (pinned_memory_buffer_ != nullptr) {
managed_pinned_memory_ = boost::interprocess::managed_external_buffer(
boost::interprocess::create_only_t{}, pinned_memory_buffer_, size);
}
}
PinnedMemoryManager::PinnedMemory::~PinnedMemory()
{
#ifdef TRITON_ENABLE_GPU
if (pinned_memory_buffer_ != nullptr) {
cudaFreeHost(pinned_memory_buffer_);
}
#endif // TRITON_ENABLE_GPU
}
PinnedMemoryManager::~PinnedMemoryManager()
{
// Clean up
for (const auto& memory_info : memory_info_) {
const auto& is_pinned = memory_info.second.first;
if (!is_pinned) {
free(memory_info.first);
}
}
}
void
PinnedMemoryManager::AddPinnedMemoryBuffer(
const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
unsigned long node_mask)
{
pinned_memory_buffers_[node_mask] = pinned_memory_buffer;
}
Status
PinnedMemoryManager::AllocInternal(
void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer)
{
auto status = Status::Success;
if (pinned_memory_buffer->pinned_memory_buffer_ != nullptr) {
std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
*ptr = pinned_memory_buffer->managed_pinned_memory_.allocate(
size, std::nothrow_t{});
*allocated_type = TRITONSERVER_MEMORY_CPU_PINNED;
if (*ptr == nullptr) {
status = Status(
Status::Code::INTERNAL, "failed to allocate pinned system memory");
}
} else {
status = Status(
Status::Code::INTERNAL,
"failed to allocate pinned system memory: no pinned memory pool");
}
bool is_pinned = true;
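// If pinned allocation failed and fallback is allowed, retry with regular
// (non-pinned) system memory and emit a one-time warning.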
if ((!status.IsOk()) && allow_nonpinned_fallback) {
static bool warning_logged = false;
if (!warning_logged) {
LOG_WARNING << status.Message()
<< ", falling back to non-pinned system memory";
warning_logged = true;
}
*ptr = malloc(size);
*allocated_type = TRITONSERVER_MEMORY_CPU;
is_pinned = false;
if (*ptr == nullptr) {
status = Status(
Status::Code::INTERNAL,
"failed to allocate non-pinned system memory");
} else {
status = Status::Success;
}
}
// Keep track of the allocated buffer; clean-up on failure happens below.
{
std::lock_guard<std::mutex> lk(info_mtx_);
if (status.IsOk()) {
auto res = memory_info_.emplace(
*ptr, std::make_pair(is_pinned, pinned_memory_buffer));
if (!res.second) {
status = Status(
Status::Code::INTERNAL, "unexpected memory address collision, '" +
PointerToString(*ptr) +
"' has been managed");
}
LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
<< "pinned memory allocation: "
<< "size " << size << ", addr " << *ptr;
}
}
if ((!status.IsOk()) && (*ptr != nullptr)) {
if (is_pinned) {
std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
pinned_memory_buffer->managed_pinned_memory_.deallocate(*ptr);
} else {
free(*ptr);
}
}
return status;
}
Status
PinnedMemoryManager::FreeInternal(void* ptr)
{
bool is_pinned = true;
PinnedMemory* pinned_memory_buffer = nullptr;
{
std::lock_guard<std::mutex> lk(info_mtx_);
auto it = memory_info_.find(ptr);
if (it != memory_info_.end()) {
is_pinned = it->second.first;
pinned_memory_buffer = it->second.second;
LOG_VERBOSE(1) << (is_pinned ? "" : "non-")
<< "pinned memory deallocation: "
<< "addr " << ptr;
memory_info_.erase(it);
} else {
return Status(
Status::Code::INTERNAL, "unexpected memory address '" +
PointerToString(ptr) +
"' is not being managed");
}
}
if (is_pinned) {
std::lock_guard<std::mutex> lk(pinned_memory_buffer->buffer_mtx_);
pinned_memory_buffer->managed_pinned_memory_.deallocate(ptr);
} else {
free(ptr);
}
return Status::Success;
}
void
PinnedMemoryManager::Reset()
{
instance_.reset();
}
Status
PinnedMemoryManager::Create(const Options& options)
{
if (instance_ != nullptr) {
LOG_WARNING << "New pinned memory pool of size "
<< options.pinned_memory_pool_byte_size_
<< " could not be created since one already exists"
<< " of size " << pinned_memory_byte_size_;
return Status::Success;
}
instance_.reset(new PinnedMemoryManager());
if (options.host_policy_map_.empty()) {
void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
auto err = cudaHostAlloc(
&buffer, options.pinned_memory_pool_byte_size_, cudaHostAllocPortable);
if (err != cudaSuccess) {
buffer = nullptr;
LOG_WARNING << "Unable to allocate pinned system memory, pinned memory "
"pool will not be available: "
<< std::string(cudaGetErrorString(err));
} else if (options.pinned_memory_pool_byte_size_ != 0) {
LOG_INFO << "Pinned memory pool is created at '"
<< PointerToString(buffer) << "' with size "
<< options.pinned_memory_pool_byte_size_;
} else {
LOG_INFO << "Pinned memory pool disabled";
}
#endif // TRITON_ENABLE_GPU
try {
instance_->AddPinnedMemoryBuffer(
std::shared_ptr<PinnedMemory>(
new PinnedMemory(buffer, options.pinned_memory_pool_byte_size_)),
0);
}
catch (const std::exception& ex) {
return Status(
Status::Code::INTERNAL,
"Failed to add Pinned Memory buffer: " + std::string(ex.what()));
}
} else {
// Only one buffer/manager should be created per NUMA node; all devices
// associated with that node should request memory from the shared manager.
std::map<int32_t, std::string> numa_map;
for (const auto host_policy : options.host_policy_map_) {
const auto numa_it = host_policy.second.find("numa-node");
if (numa_it != host_policy.second.end()) {
int32_t numa_id;
if (ParseIntOption("Parsing NUMA node", numa_it->second, &numa_id)
.IsOk()) {
numa_map.emplace(numa_id, host_policy.first);
}
}
}
for (const auto node_policy : numa_map) {
auto status =
SetNumaMemoryPolicy(options.host_policy_map_.at(node_policy.second));
if (!status.IsOk()) {
LOG_WARNING << "Unable to allocate pinned system memory for NUMA node "
<< node_policy.first << ": " << status.AsString();
continue;
}
unsigned long node_mask;
status = GetNumaMemoryPolicyNodeMask(&node_mask);
if (!status.IsOk()) {
LOG_WARNING << "Unable to get NUMA node set for current thread: "
<< status.AsString();
continue;
}
void* buffer = nullptr;
#ifdef TRITON_ENABLE_GPU
auto err = cudaHostAlloc(
&buffer, options.pinned_memory_pool_byte_size_,
cudaHostAllocPortable);
if (err != cudaSuccess) {
buffer = nullptr;
LOG_WARNING << "Unable to allocate pinned system memory, pinned memory "
"pool will not be available: "
<< std::string(cudaGetErrorString(err));
} else if (options.pinned_memory_pool_byte_size_ != 0) {
LOG_INFO << "Pinned memory pool is created at '"
<< PointerToString(buffer) << "' with size "
<< options.pinned_memory_pool_byte_size_;
} else {
LOG_INFO << "Pinned memory pool disabled";
}
#endif // TRITON_ENABLE_GPU
ResetNumaMemoryPolicy();
try {
instance_->AddPinnedMemoryBuffer(
std::shared_ptr<PinnedMemory>(new PinnedMemory(
buffer, options.pinned_memory_pool_byte_size_)),
node_mask);
}
catch (const std::exception& ex) {
return Status(
Status::Code::INTERNAL,
"Failed to add Pinned Memory buffer with host policy: " +
std::string(ex.what()));
}
}
// If no pinned memory was allocated, add an empty entry so that all
// allocations fall back to normal system memory.
if (instance_->pinned_memory_buffers_.empty()) {
try {
instance_->AddPinnedMemoryBuffer(
std::shared_ptr<PinnedMemory>(new PinnedMemory(
nullptr, options.pinned_memory_pool_byte_size_)),
0);
}
catch (const std::exception& ex) {
return Status(
Status::Code::INTERNAL,
"Failed to add empty Pinned Memory entry: " +
std::string(ex.what()));
}
}
}
pinned_memory_byte_size_ = options.pinned_memory_pool_byte_size_;
return Status::Success;
}
Status
PinnedMemoryManager::Alloc(
void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
bool allow_nonpinned_fallback)
{
if (instance_ == nullptr) {
return Status(
Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
}
auto pinned_memory_buffer =
instance_->pinned_memory_buffers_.begin()->second.get();
if (instance_->pinned_memory_buffers_.size() > 1) {
unsigned long node_mask;
if (GetNumaMemoryPolicyNodeMask(&node_mask).IsOk()) {
auto it = instance_->pinned_memory_buffers_.find(node_mask);
if (it != instance_->pinned_memory_buffers_.end()) {
pinned_memory_buffer = it->second.get();
}
}
}
return instance_->AllocInternal(
ptr, size, allocated_type, allow_nonpinned_fallback,
pinned_memory_buffer);
}
Status
PinnedMemoryManager::Free(void* ptr)
{
if (instance_ == nullptr) {
return Status(
Status::Code::UNAVAILABLE, "PinnedMemoryManager has not been created");
}
return instance_->FreeInternal(ptr);
}
}} // namespace triton::core
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <boost/interprocess/managed_external_buffer.hpp>
#include <map>
#include <memory>
#include <mutex>
#include "status.h"
#include "triton/common/model_config.h"
namespace triton { namespace core {
// This is a singleton class responsible for maintaining pinned memory pool
// used by the inference server. Pinned memory allocations and deallocations
// must be requested via functions provided by this class.
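// A minimal usage sketch (illustrative; assumes the manager is created once
// at server startup and omits error handling for brevity):
//
//   PinnedMemoryManager::Options options(256 * 1024 * 1024 /* bytes */);
//   Status status = PinnedMemoryManager::Create(options);
//
//   void* buffer = nullptr;
//   TRITONSERVER_MemoryType allocated_type;
//   status = PinnedMemoryManager::Alloc(
//       &buffer, 1024, &allocated_type, true /* allow_nonpinned_fallback */);
//   // ... use 'buffer' ...
//   status = PinnedMemoryManager::Free(buffer);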
class PinnedMemoryManager {
public:
// Options to configure the pinned memory manager.
struct Options {
Options(
uint64_t b = 0,
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map = {})
: pinned_memory_pool_byte_size_(b), host_policy_map_(host_policy_map)
{
}
uint64_t pinned_memory_pool_byte_size_;
triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
};
~PinnedMemoryManager();
// Create the pinned memory manager based on 'options' specified.
// Return Status object indicating success or failure.
static Status Create(const Options& options);
// Allocate pinned memory with the requested 'size' and return the pointer
// in 'ptr'. If 'allow_nonpinned_fallback' is true, regular system memory
// will be allocated as fallback in the case where pinned memory fails to
// be allocated.
// Return Status object indicating success or failure.
static Status Alloc(
void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
bool allow_nonpinned_fallback);
// Free the memory allocated by the pinned memory manager.
// Return Status object indicating success or failure.
static Status Free(void* ptr);
protected:
// Provide explicit control over the lifecycle of the pinned memory manager,
// for testing only.
static void Reset();
private:
class PinnedMemory {
public:
PinnedMemory(void* pinned_memory_buffer, uint64_t size);
~PinnedMemory();
void* pinned_memory_buffer_;
std::mutex buffer_mtx_;
boost::interprocess::managed_external_buffer managed_pinned_memory_;
};
PinnedMemoryManager() = default;
Status AllocInternal(
void** ptr, uint64_t size, TRITONSERVER_MemoryType* allocated_type,
bool allow_nonpinned_fallback, PinnedMemory* pinned_memory_buffer);
Status FreeInternal(void* ptr);
void AddPinnedMemoryBuffer(
const std::shared_ptr<PinnedMemory>& pinned_memory_buffer,
unsigned long node_mask);
static std::unique_ptr<PinnedMemoryManager> instance_;
static uint64_t pinned_memory_byte_size_;
std::mutex info_mtx_;
std::map<void*, std::pair<bool, PinnedMemory*>> memory_info_;
std::map<unsigned long, std::shared_ptr<PinnedMemory>> pinned_memory_buffers_;
};
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "rate_limiter.h"
#include <limits>
#include "triton/common/logging.h"
namespace triton { namespace core {
constexpr size_t MAX_PAYLOAD_BUCKET_COUNT = 1000;
//=========================================================================
// Core Implementation
//=========================================================================
Status
RateLimiter::Create(
const bool ignore_resources_and_priority,
const RateLimiter::ResourceMap& resource_map,
std::unique_ptr<RateLimiter>* rate_limiter)
{
std::unique_ptr<RateLimiter> local_rate_limiter(
new RateLimiter(ignore_resources_and_priority, resource_map));
*rate_limiter = std::move(local_rate_limiter);
return Status::Success;
}
Status
RateLimiter::RegisterModelInstance(
TritonModelInstance* triton_model_instance,
const RateLimiterConfig& rate_limiter_config)
{
{
std::lock_guard<std::mutex> lk1(model_ctx_mtx_);
std::lock_guard<std::mutex> lk2(model_instance_ctx_mtx_);
auto& model_context = model_contexts_[triton_model_instance->Model()];
auto& model_instances =
model_instance_ctxs_[triton_model_instance->Model()];
model_instances.push_back(
std::shared_ptr<ModelInstanceContext>(new ModelInstanceContext(
triton_model_instance, &model_context, rate_limiter_config,
[this](ModelInstanceContext* instance) { OnStage(instance); },
[this](ModelInstanceContext* instance) { OnRelease(instance); })));
model_context.AddAvailableInstance(model_instances.back().get());
model_context.AddSpecificRequestQueue();
if (!ignore_resources_and_priority_) {
resource_manager_->AddModelInstance(model_instances.back().get());
RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits());
}
}
InitializePayloadQueues(triton_model_instance);
return Status::Success;
}
Status
RateLimiter::UnregisterModel(const TritonModel* model)
{
{
std::lock_guard<std::mutex> lk1(model_ctx_mtx_);
std::lock_guard<std::mutex> lk2(model_instance_ctx_mtx_);
auto& model_context = model_contexts_[model];
model_context.RequestRemoval();
for (const auto& instance : model_instance_ctxs_[model]) {
instance->WaitForRemoval();
if (!ignore_resources_and_priority_) {
resource_manager_->RemoveModelInstance(instance.get());
}
}
model_instance_ctxs_.erase(model);
model_contexts_.erase(model);
}
if (!ignore_resources_and_priority_) {
RETURN_IF_ERROR(resource_manager_->UpdateResourceLimits());
}
{
std::lock_guard<std::mutex> lk(payload_queues_mu_);
if (payload_queues_.find(model) != payload_queues_.end()) {
payload_queues_.erase(model);
}
}
return Status::Success;
}
bool
RateLimiter::PayloadSlotAvailable(const TritonModel* model)
{
bool result;
PayloadQueue* payload_queue = payload_queues_[model].get();
{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
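// Heuristic: a slot is considered available while fewer than two payloads
// per registered model instance are pending in the model's generic queue.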
result = payload_queue->queue_->Size() <
2 * payload_queue->specific_queues_.size();
}
return result;
}
Status
RateLimiter::EnqueuePayload(
const TritonModel* model, std::shared_ptr<Payload> payload)
{
auto pinstance = payload->GetInstance();
if (payload_queues_.find(model) == payload_queues_.end()) {
LOG_INFO << "Should not print this ";
}
PayloadQueue* payload_queue = payload_queues_[model].get();
{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
payload->SetState(Payload::State::REQUESTED);
if (ignore_resources_and_priority_) {
SchedulePayload(pinstance, payload_queue, payload);
}
}
if (ignore_resources_and_priority_) {
if (pinstance == nullptr) {
payload_queue->cv_.notify_one();
} else {
payload_queue->cv_.notify_all();
}
} else {
StandardScheduleFunc sched_func = [this, payload_queue,
payload](ModelInstanceContext* mi) {
{
std::lock_guard<std::mutex> lk(payload_queue->mu_);
this->SchedulePayload(mi->RawInstance(), payload_queue, payload);
}
auto cb = [mi]() { mi->Release(); };
payload->AddInternalReleaseCallback(cb);
if (mi->RawInstance() == nullptr) {
payload_queue->cv_.notify_one();
} else {
payload_queue->cv_.notify_all();
}
};
DeferPayloadSchedule(sched_func, model, payload->GetInstance());
}
return Status::Success;
}
void
RateLimiter::DequeuePayload(
std::deque<TritonModelInstance*>& instances,
std::shared_ptr<Payload>* payload)
{
payload->reset();
if (payload_queues_.find(instances[0]->Model()) == payload_queues_.end()) {
LOG_INFO << "Should not print this ";
}
PayloadQueue* payload_queue = payload_queues_[instances[0]->Model()].get();
std::vector<std::shared_ptr<Payload>> merged_payloads;
size_t instance_index = std::numeric_limits<std::size_t>::max();
{
std::unique_lock<std::mutex> lk(payload_queue->mu_);
payload_queue->cv_.wait(lk, [&instances, &instance_index, payload_queue]() {
bool empty = payload_queue->queue_->Empty();
if (empty) {
instance_index = 0;
for (const auto instance : instances) {
empty = payload_queue->specific_queues_[instance]->Empty();
if (empty) {
instance_index++;
} else {
break;
}
}
}
return !empty;
});
if (instance_index < instances.size()) {
TritonModelInstance* instance = instances[instance_index];
if (!payload_queue->specific_queues_[instance]->Empty()) {
payload_queue->specific_queues_[instance]->Dequeue(
payload, &merged_payloads);
}
} else {
payload_queue->queue_->Dequeue(payload, &merged_payloads);
}
}
for (auto& merge_payload : merged_payloads) {
PayloadRelease(merge_payload);
}
(*payload)->Callback();
if ((*payload)->GetInstance() == nullptr) {
(*payload)->SetInstance(instances.front());
instances.pop_front();
} else {
instances.erase(instances.begin() + instance_index);
}
}
std::shared_ptr<Payload>
RateLimiter::GetPayload(
const Payload::Operation op_type, TritonModelInstance* instance)
{
std::shared_ptr<Payload> payload;
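// Try to reuse a pooled Payload before allocating a new one: first take one
// from the bucket, otherwise reclaim the oldest in-use payload if it is no
// longer referenced anywhere else.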
if (max_payload_bucket_count_ > 0) {
std::lock_guard<std::mutex> lock(payload_mu_);
if (!payload_bucket_.empty()) {
payload = payload_bucket_.back();
payload_bucket_.pop_back();
}
if (payload.get() == nullptr && (!payloads_in_use_.empty())) {
// Only the front of the queue is checked, instead of the entire queue,
// to find an available payload quickly.
if (payloads_in_use_.front().use_count() == 1) {
payload = payloads_in_use_.front();
payloads_in_use_.pop_front();
}
}
}
if (payload.get() == nullptr) {
payload.reset(new Payload());
}
payload->Reset(op_type, instance);
return payload;
}
void
RateLimiter::PayloadRelease(std::shared_ptr<Payload>& payload)
{
payload->OnRelease();
if (max_payload_bucket_count_ > 0) {
std::lock_guard<std::mutex> lock(payload_mu_);
if (payloads_in_use_.size() + payload_bucket_.size() <
max_payload_bucket_count_) {
// Release iff the payload shared_ptr is uniquely held.
if (payload.use_count() == 1) {
payload->Release();
payload_bucket_.push_back(std::move(payload));
return;
} else {
payloads_in_use_.push_back(std::move(payload));
}
}
}
}
RateLimiter::RateLimiter(
const bool ignore_resources_and_priority, const ResourceMap& resource_map)
: ignore_resources_and_priority_(ignore_resources_and_priority),
max_payload_bucket_count_(MAX_PAYLOAD_BUCKET_COUNT)
{
ResourceManager::Create(resource_map, &resource_manager_);
}
void
RateLimiter::InitializePayloadQueues(const TritonModelInstance* instance)
{
auto& config = instance->Model()->Config();
uint64_t max_queue_delay_microseconds;
if (config.has_sequence_batching()) {
const auto& batcher_config = config.sequence_batching();
if (batcher_config.has_oldest()) {
max_queue_delay_microseconds =
batcher_config.oldest().max_queue_delay_microseconds();
} else {
max_queue_delay_microseconds = 0;
}
} else if (config.has_dynamic_batching()) {
max_queue_delay_microseconds =
config.dynamic_batching().max_queue_delay_microseconds();
} else {
max_queue_delay_microseconds = 0;
}
{
std::lock_guard<std::mutex> lk(payload_queues_mu_);
if (payload_queues_.find(instance->Model()) == payload_queues_.end()) {
payload_queues_.emplace(
instance->Model(),
new PayloadQueue(
config.max_batch_size(), max_queue_delay_microseconds * 1000));
}
}
PayloadQueue* payload_queue = payload_queues_[instance->Model()].get();
if (payload_queue->specific_queues_.find(instance) ==
payload_queue->specific_queues_.end()) {
payload_queue->specific_queues_.emplace(
instance,
new InstanceQueue(
config.max_batch_size(), max_queue_delay_microseconds * 1000));
}
}
Status
RateLimiter::DeferPayloadSchedule(
const StandardScheduleFunc& OnSchedule, const TritonModel* model,
TritonModelInstance* triton_model_instance)
{
std::lock_guard<std::mutex> lk(model_ctx_mtx_);
auto itr = model_contexts_.find(model);
if (itr == model_contexts_.end()) {
return Status(
Status::Code::INTERNAL,
"Requested model is not yet registered with rate limiter");
}
if (itr->second.isRemovalInProgress()) {
return Status(
Status::Code::INTERNAL,
"New model requests can not be made to a model that is being "
"removed");
}
itr->second.EnqueueModelInstanceRequest(OnSchedule, triton_model_instance);
itr->second.StageInstanceIfAvailable(triton_model_instance);
return Status::Success;
}
void
RateLimiter::SchedulePayload(
TritonModelInstance* tmi, PayloadQueue* payload_queue,
const std::shared_ptr<Payload>& payload)
{
if (tmi == nullptr) {
payload_queue->queue_->Enqueue(payload);
} else {
payload_queue->specific_queues_[tmi]->Enqueue(payload);
}
payload->SetState(Payload::State::SCHEDULED);
}
void
RateLimiter::OnStage(ModelInstanceContext* instance)
{
{
std::lock_guard<std::recursive_mutex> lk(staged_instances_mtx_);
staged_instances_.push(instance);
}
AttemptAllocation();
}
void
RateLimiter::OnRelease(ModelInstanceContext* instance)
{
auto& model_context = model_contexts_[instance->RawInstance()->Model()];
model_context.AddAvailableInstance(instance);
resource_manager_->ReleaseResources(instance);
if (model_context.ContainsPendingRequests(instance->RawInstance()->Index())) {
model_context.StageInstanceIfAvailable(instance->RawInstance());
}
AttemptAllocation();
}
void
RateLimiter::AttemptAllocation()
{
std::lock_guard<std::recursive_mutex> lk(staged_instances_mtx_);
if (!staged_instances_.empty()) {
ModelInstanceContext* instance = staged_instances_.top();
if (resource_manager_->AllocateResources(instance)) {
staged_instances_.pop();
instance->Allocate();
}
}
}
//=========================================================================
// ModelContext Implementation
//=========================================================================
RateLimiter::ModelContext::ModelContext() : removal_in_progress_(false) {}
Status
RateLimiter::ModelContext::EnqueueModelInstanceRequest(
const StandardScheduleFunc& OnSchedule,
TritonModelInstance* triton_model_instance)
{
std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
if (triton_model_instance == nullptr) {
generic_sched_request_queue_.push(OnSchedule);
} else if (
(uint32_t)triton_model_instance->Index() <
specific_sched_request_queues_.size()) {
specific_sched_request_queues_[triton_model_instance->Index()].push(
OnSchedule);
} else {
return Status(
Status::Code::INTERNAL,
"expected instance index between 0 and " +
std::to_string(specific_sched_request_queues_.size()) + ", got " +
std::to_string(triton_model_instance->Index()));
}
return Status::Success;
}
void
RateLimiter::ModelContext::AddAvailableInstance(ModelInstanceContext* instance)
{
std::lock_guard<std::recursive_mutex> lk(avbl_instances_mtx_);
avbl_instances_.push(instance);
instance->MarkAvailable();
}
void
RateLimiter::ModelContext::StageInstanceIfAvailable(
TritonModelInstance* req_instance)
{
std::lock_guard<std::recursive_mutex> lk1(sched_request_queue_mtx_);
std::lock_guard<std::recursive_mutex> lk2(avbl_instances_mtx_);
PriorityQueue backup_queue;
while (!avbl_instances_.empty()) {
ModelInstanceContext* instance = avbl_instances_.top();
if ((req_instance != nullptr) &&
(instance->RawInstance() != req_instance)) {
backup_queue.push(instance);
avbl_instances_.pop();
continue;
}
if (!specific_sched_request_queues_[instance->RawInstance()->Index()]
.empty()) {
// Give the requests targeted at this specific model instance the
// highest priority.
const StandardScheduleFunc func =
specific_sched_request_queues_[instance->RawInstance()->Index()]
.front();
specific_sched_request_queues_[instance->RawInstance()->Index()].pop();
instance->Stage(func);
} else if (!generic_sched_request_queue_.empty()) {
// If the request is for any (generic) model instance, serve it with
// this instance, the highest-priority available one.
const StandardScheduleFunc func = generic_sched_request_queue_.front();
generic_sched_request_queue_.pop();
instance->Stage(func);
} else {
// No pending request can use this instance right now, so set it aside
// in the backup queue and keep searching through the remaining
// available instances. Prioritization is handled by the priority
// queue when the backup is restored.
backup_queue.push(instance);
}
avbl_instances_.pop();
}
// Restore the backup queue
if (!backup_queue.empty()) {
avbl_instances_.swap(backup_queue);
}
}
void
RateLimiter::ModelContext::AllocateInstanceIfAvailable()
{
std::lock_guard<std::recursive_mutex> lk1(sched_request_queue_mtx_);
std::lock_guard<std::recursive_mutex> lk2(avbl_instances_mtx_);
PriorityQueue backup_queue;
while (!avbl_instances_.empty()) {
ModelInstanceContext* instance = avbl_instances_.top();
if (!specific_sched_request_queues_[instance->RawInstance()->Index()]
.empty()) {
// Give the requests targeted at this specific model instance the
// highest priority.
const StandardScheduleFunc func =
specific_sched_request_queues_[instance->RawInstance()->Index()]
.front();
specific_sched_request_queues_[instance->RawInstance()->Index()].pop();
instance->DirectAllocate(func);
} else if (!generic_sched_request_queue_.empty()) {
// If the request is for any (generic) model instance, serve it with
// this instance, the highest-priority available one.
const StandardScheduleFunc func = generic_sched_request_queue_.front();
generic_sched_request_queue_.pop();
instance->DirectAllocate(func);
} else {
// No pending request can use this instance right now, so set it aside
// in the backup queue and keep searching through the remaining
// available instances. Prioritization is handled by the priority
// queue when the backup is restored.
backup_queue.push(instance);
}
avbl_instances_.pop();
}
// Restore the backup queue
if (!backup_queue.empty()) {
avbl_instances_.swap(backup_queue);
}
}
void
RateLimiter::ModelContext::AddSpecificRequestQueue()
{
std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
specific_sched_request_queues_.emplace_back();
}
bool
RateLimiter::ModelContext::ContainsPendingRequests(int index)
{
std::lock_guard<std::recursive_mutex> lk(sched_request_queue_mtx_);
return (generic_sched_request_queue_.size() != 0) ||
(specific_sched_request_queues_[index].size() != 0);
}
void
RateLimiter::ModelContext::RequestRemoval()
{
removal_in_progress_ = true;
}
//=========================================================================
// ModelInstanceContext Implementation
//=========================================================================
RateLimiter::ModelInstanceContext::ModelInstanceContext(
TritonModelInstance* triton_model_instance,
RateLimiter::ModelContext* model_context,
const RateLimiter::RateLimiterConfig& rate_limiter_config,
RateLimiter::StandardStageFunc OnStage,
RateLimiter::StandardReleaseFunc OnRelease)
: triton_model_instance_(triton_model_instance),
index_(triton_model_instance->Index()), model_context_(model_context),
rate_limiter_config_(rate_limiter_config), OnStage_(OnStage),
OnRelease_(OnRelease), exec_count_(0), state_(AVAILABLE)
{
}
void
RateLimiter::ModelInstanceContext::MarkAvailable()
{
std::lock_guard<std::mutex> lk(state_mtx_);
state_ = AVAILABLE;
}
Status
RateLimiter::ModelInstanceContext::Stage(StandardScheduleFunc OnSchedule)
{
{
std::lock_guard<std::mutex> lk(state_mtx_);
if (state_ != AVAILABLE) {
return Status(
Status::Code::INTERNAL,
"Can not stage a model instance that is not yet available");
}
state_ = STAGED;
OnSchedule_ = OnSchedule;
}
OnStage_(this);
return Status::Success;
}
Status
RateLimiter::ModelInstanceContext::Allocate()
{
{
std::lock_guard<std::mutex> lk(state_mtx_);
if (state_ != STAGED) {
return Status(
Status::Code::INTERNAL,
"Can not allocate a model instance that is not yet staged");
}
state_ = ALLOCATED;
}
OnSchedule_(this);
return Status::Success;
}
Status
RateLimiter::ModelInstanceContext::DirectAllocate(
StandardScheduleFunc OnSchedule)
{
{
std::lock_guard<std::mutex> lk(state_mtx_);
if (state_ != AVAILABLE) {
return Status(
Status::Code::INTERNAL,
"Can not allocate a model instance that is not yet available");
}
state_ = ALLOCATED;
}
OnSchedule(this);
return Status::Success;
}
void
RateLimiter::ModelInstanceContext::Release()
{
exec_count_++;
OnRelease_(this);
{
std::lock_guard<std::mutex> lk(state_mtx_);
if ((model_context_->isRemovalInProgress()) && (state_ == AVAILABLE) &&
(!model_context_->ContainsPendingRequests(index_))) {
state_ = REMOVED;
}
}
if (state_ == REMOVED) {
cv_.notify_all();
}
}
void
RateLimiter::ModelInstanceContext::RequestRemoval()
{
std::lock_guard<std::mutex> lk(state_mtx_);
if ((state_ == AVAILABLE) &&
(!model_context_->ContainsPendingRequests(index_))) {
state_ = REMOVED;
}
}
void
RateLimiter::ModelInstanceContext::WaitForRemoval()
{
if (!model_context_->isRemovalInProgress()) {
model_context_->RequestRemoval();
}
RequestRemoval();
// Wait for the instance to be removed
{
std::unique_lock<std::mutex> lk(state_mtx_);
cv_.wait(lk, [this] { return state_ == REMOVED; });
}
}
double
RateLimiter::ModelInstanceContext::ScaledPriority()
{
// TODO: Different schemes for prioritizing model instances can be
// added here.
// The priority of an instance defaults to 1; a configured priority of 0
// is also treated as 1.
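// For example (illustrative): an instance configured with priority 2 that
// has executed 3 payloads yields a scaled priority of 6. Because
// ScaledPriorityComparator orders the queue with '>', the instance with the
// smallest scaled value sits at the top and is staged/allocated first.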
auto priority = std::max(rate_limiter_config_.priority(), 1u);
return (exec_count_ * priority);
}
//=========================================================================
// ResourceManager Implementation
//=========================================================================
Status
RateLimiter::ResourceManager::Create(
const ResourceMap& resource_map,
std::unique_ptr<ResourceManager>* resource_manager)
{
std::unique_ptr<ResourceManager> local_resource_manager(
new ResourceManager(resource_map));
*resource_manager = std::move(local_resource_manager);
return Status::Success;
}
void
RateLimiter::ResourceManager::AddModelInstance(
const ModelInstanceContext* instance)
{
std::lock_guard<std::mutex> lk(model_resources_mtx_);
auto pr = model_resources_.emplace(std::make_pair(instance, ResourceMap()));
for (const auto& resource : instance->GetRateLimiterConfig()->resources()) {
if (resource.global()) {
(pr.first->second[GLOBAL_RESOURCE_KEY])[resource.name()] =
resource.count();
} else {
(pr.first->second[instance->RawInstance()->DeviceId()])[resource.name()] =
resource.count();
}
}
}
Status
RateLimiter::ResourceManager::RemoveModelInstance(
const ModelInstanceContext* instance)
{
std::lock_guard<std::mutex> lk(model_resources_mtx_);
const auto& itr = model_resources_.find(instance);
if (itr == model_resources_.end()) {
return Status(
Status::Code::INTERNAL, "Can not find the instance to remove");
}
model_resources_.erase(instance);
return Status::Success;
}
Status
RateLimiter::ResourceManager::UpdateResourceLimits()
{
std::lock_guard<std::mutex> lk1(max_resources_mtx_);
std::lock_guard<std::mutex> lk2(model_resources_mtx_);
max_resources_.clear();
// Take the maximum count of each resource across all the instances
// and use it as the default available amount.
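// For example (illustrative): if one instance declares resource "R1" with
// count 4 on device 0 and another declares "R1" with count 10 on device 0,
// the default maximum for "R1" on device 0 becomes 10.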
for (const auto& instance_resources : model_resources_) {
for (const auto& resource_device_map : instance_resources.second) {
auto ditr = max_resources_.find(resource_device_map.first);
if (ditr == max_resources_.end()) {
ditr =
max_resources_
.emplace(resource_device_map.first, resource_device_map.second)
.first;
} else {
for (const auto resource : resource_device_map.second) {
auto ritr = ditr->second.find(resource.first);
if (ritr == ditr->second.end()) {
ritr = ditr->second.emplace(resource.first, resource.second).first;
} else {
if (ritr->second < resource.second) {
ritr->second = resource.second;
}
}
}
}
}
}
if (!explicit_max_resources_.empty()) {
RETURN_IF_ERROR(ParseAndValidateExplicitResources());
}
RETURN_IF_ERROR(ValidateMaxResources());
if (LOG_VERBOSE_IS_ON(1)) {
std::string resource_map_str{"\nMax Resource Map===>\n"};
for (const auto& ditr : max_resources_) {
if (!ditr.second.empty()) {
std::string device_str{(ditr.first == GLOBAL_RESOURCE_KEY)
? "GLOBAL"
: std::to_string(ditr.first)};
resource_map_str += "\tDevice: " + device_str + "\n";
for (const auto& ritr : ditr.second) {
resource_map_str += "\t\tResource: " + ritr.first +
"\t Count: " + std::to_string(ritr.second) + "\n";
}
}
}
LOG_VERBOSE(1) << resource_map_str;
}
return Status::Success;
}
Status
RateLimiter::ResourceManager::ValidateMaxResources()
{
for (const auto& global_resource : max_resources_[GLOBAL_RESOURCE_KEY]) {
for (const auto& ditr : max_resources_) {
if (ditr.first != GLOBAL_RESOURCE_KEY) {
for (const auto& ritr : ditr.second) {
if (global_resource.first.compare(ritr.first) == 0) {
return Status(
Status::Code::INVALID_ARG,
(std::string("Resource \"") + ritr.first +
"\" is present as both global and device-specific resource in "
"the model configuration.")
.c_str());
}
}
}
}
}
return Status::Success;
}
Status
RateLimiter::ResourceManager::ParseAndValidateExplicitResources()
{
for (auto& ditr : max_resources_) {
for (auto& ritr : ditr.second) {
// If not specified explicitly, consider the resource to be unavailable.
size_t resource_count = 0;
if (ditr.first == GLOBAL_RESOURCE_KEY) {
// Global resources ignore the device specification; search all entries
// in the explicit resource map.
for (const auto& exp_ditr : explicit_max_resources_) {
for (const auto& exp_ritr : exp_ditr.second) {
if (ritr.first.compare(exp_ritr.first) == 0) {
if (resource_count < exp_ritr.second) {
resource_count = exp_ritr.second;
}
}
}
}
} else {
// Search only the device-specific and per-device resource entries.
// device-specific
for (const auto& exp_ritr : explicit_max_resources_[ditr.first]) {
if (ritr.first.compare(exp_ritr.first) == 0) {
if (resource_count < exp_ritr.second) {
resource_count = exp_ritr.second;
}
}
}
// per-device
for (const auto& exp_ritr :
explicit_max_resources_[PER_DEVICE_RESOURCE_KEY]) {
if (ritr.first.compare(exp_ritr.first) == 0) {
if (resource_count < exp_ritr.second) {
resource_count = exp_ritr.second;
}
}
}
}
if (resource_count < ritr.second) {
return Status(
Status::Code::INVALID_ARG,
(std::string("Resource count for \"") + ritr.first +
"\" is limited to " + std::to_string(resource_count) +
" which will prevent scheduling of one or more model "
"instances, the minimum required count is " +
std::to_string(ritr.second))
.c_str());
} else {
ritr.second = resource_count;
}
}
}
return Status::Success;
}
bool
RateLimiter::ResourceManager::AllocateResources(
const ModelInstanceContext* instance)
{
std::lock_guard<std::mutex> lk1(model_resources_mtx_);
std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
const auto& itr = model_resources_.find(instance);
if (itr == model_resources_.end()) {
return false;
} else {
// First pass to verify if resources are available
{
std::lock_guard<std::mutex> lk3(max_resources_mtx_);
for (const auto& ditr : itr->second) {
auto allocated_ditr = allocated_resources_.find(ditr.first);
if (allocated_ditr == allocated_resources_.end()) {
allocated_ditr =
allocated_resources_
.emplace(ditr.first, std::map<std::string, size_t>())
.first;
}
for (const auto& ritr : ditr.second) {
auto allocated_ritr = allocated_ditr->second.find(ritr.first);
if (allocated_ritr == allocated_ditr->second.end()) {
allocated_ritr =
allocated_ditr->second.emplace(ritr.first, 0).first;
}
if ((allocated_ritr->second + ritr.second) >
(max_resources_[ditr.first])[ritr.first]) {
return false;
}
}
}
}
// Second pass to actually allocate the resources
for (const auto& ditr : itr->second) {
for (const auto& ritr : ditr.second) {
(allocated_resources_[ditr.first])[ritr.first] += ritr.second;
}
}
}
return true;
}
Status
RateLimiter::ResourceManager::ReleaseResources(
const ModelInstanceContext* instance)
{
std::lock_guard<std::mutex> lk1(model_resources_mtx_);
std::lock_guard<std::mutex> lk2(allocated_resources_mtx_);
const auto& itr = model_resources_.find(instance);
if (itr == model_resources_.end()) {
return Status(
Status::Code::INTERNAL,
"Unable find the instance resources to release");
} else {
for (const auto& ditr : itr->second) {
for (const auto& ritr : ditr.second) {
(allocated_resources_[ditr.first])[ritr.first] -= ritr.second;
}
}
}
return Status::Success;
}
RateLimiter::ResourceManager::ResourceManager(const ResourceMap& resource_map)
: explicit_max_resources_(resource_map)
{
}
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <vector>
#include "backend_model.h"
#include "backend_model_instance.h"
#include "instance_queue.h"
#include "model_config.pb.h"
#include "payload.h"
#include "status.h"
namespace triton { namespace core {
// Limits the rate at which requests are dispatched to the model instances
class RateLimiter {
public:
using RateLimiterConfig = inference::ModelRateLimiter;
using ResourceMap = std::map<int, std::map<std::string, size_t>>;
enum RESOURCE_KIND_KEY {
// Key for holding global resources
GLOBAL_RESOURCE_KEY = -2,
// Key for holding resources per each device
PER_DEVICE_RESOURCE_KEY = -1
};
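// Any non-negative key in a ResourceMap is interpreted as a specific device
// id (see ResourceManager::AddModelInstance, which keys device-specific
// resources by the instance's device id).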
/// Creates a rate limiter object which will funnel the requests to
/// the model instances. A typical lifetime of a model instance within the
/// RateLimiter transitions from available -> staged -> allocated -> available.
/// The transition from available to staged occurs when a request is
/// registered for the model. Depending upon the resource availability and
/// priority, the RateLimiter will transition an instance to the allocated state
/// at some point in the future. The staged state is skipped when
/// configured to ignore the resource constraints. The cycle in this case
/// will be available -> allocated -> available.
/// \param ignore_resources_and_priority Whether or not to ignore resource
/// constraints and cross-model priority. An available instance is directly
/// allocated when true.
/// \param resource_map The map to the available resource count provided
/// explicitly.
/// \return Status object indicating success or failure.
static Status Create(
const bool ignore_resources_and_priority, const ResourceMap& resource_map,
std::unique_ptr<RateLimiter>* rate_limiter);
/// Registers the model instance with the rate limiter.
/// \param instance The pointer to the TritonModelInstance object to register
/// with the rate limiter.
/// \param rate_limiter_config The rate limiter configuration associated with
/// the model instance.
/// \return Status object indicating success or failure.
Status RegisterModelInstance(
TritonModelInstance* instance,
const RateLimiterConfig& rate_limiter_config);
/// Remove model from the set of models being managed by the rate limiter.
/// \param model The pointer to TritonModel object to be removed.
/// \return Status object indicating success or failure.
Status UnregisterModel(const TritonModel* model);
/// Returns true if there is a payload slot available for the given model.
/// \param model The pointer to the TritonModel object to check.
/// \return slot availability in boolean.
bool PayloadSlotAvailable(const TritonModel* model);
/// Enqueues the payload to rate limiter for scheduling on the given model.
/// \param model The pointer to the TritonModel object that will run the payload.
/// \param payload The shared pointer to the payload object.
/// \return Status object indicating success or failure.
Status EnqueuePayload(
const TritonModel* model, std::shared_ptr<Payload> payload);
/// Returns the payload that has been scheduled for the given set of model
/// instances. Note that this call is blocking and depends upon the
/// availability of payloads in the rate limiter for the triton model
/// instance.
/// \param instance The pointers to TritonModelInstance objects whose
/// payload is being requested.
/// \param payload The shared pointer to the payload object.
void DequeuePayload(
std::deque<TritonModelInstance*>& instance,
std::shared_ptr<Payload>* payload);
/// Returns a new payload object.
/// \param op_type The operation type for the payload.
/// \param instance Optional field that provides the model instance that must
/// be used for the execution of the payload. Default is nullptr which allows
/// any model instance to execute the payload.
/// \return The shared pointer to a new payload object.
std::shared_ptr<Payload> GetPayload(
const Payload::Operation op_type,
TritonModelInstance* instance = nullptr);
/// Releases the given payload object back to the rate limiter.
/// \param payload The payload to release.
void PayloadRelease(std::shared_ptr<Payload>& payload);
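/// A condensed, illustrative sketch of how these entry points are typically
/// combined (the surrounding scheduler, threading, and variable names are
/// assumptions, not part of this header):
///
///   std::unique_ptr<RateLimiter> rate_limiter;
///   RateLimiter::Create(
///       false /* ignore_resources_and_priority */, resource_map,
///       &rate_limiter);
///   rate_limiter->RegisterModelInstance(instance, rate_limiter_config);
///
///   auto payload = rate_limiter->GetPayload(Payload::Operation::INFER_RUN);
///   // ... add requests and a completion callback to the payload ...
///   rate_limiter->EnqueuePayload(model, payload);
///
///   // On an instance thread:
///   std::deque<TritonModelInstance*> instances{instance};
///   std::shared_ptr<Payload> scheduled;
///   rate_limiter->DequeuePayload(instances, &scheduled);
///   // ... execute, then:
///   rate_limiter->PayloadRelease(scheduled);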
private:
class ModelInstanceContext;
class ModelContext;
struct PayloadQueue;
using StandardReleaseFunc = std::function<void(ModelInstanceContext*)>;
using StandardScheduleFunc = std::function<void(ModelInstanceContext*)>;
using StandardStageFunc = std::function<void(ModelInstanceContext*)>;
// Holds the state of the model instance.
class ModelInstanceContext {
public:
friend class RateLimiter;
friend class ResourceManager;
enum State { AVAILABLE, STAGED, ALLOCATED, REMOVED };
void Release();
TritonModelInstance* RawInstance() const { return triton_model_instance_; }
private:
ModelInstanceContext(
TritonModelInstance* triton_model_instance, ModelContext* model_context,
const RateLimiterConfig& rate_limiter_config, StandardStageFunc OnStage,
StandardReleaseFunc OnRelease);
const RateLimiterConfig* GetRateLimiterConfig() const
{
return &rate_limiter_config_;
}
void MarkAvailable();
double ScaledPriority();
Status Stage(StandardScheduleFunc OnSchedule);
Status Allocate();
Status DirectAllocate(StandardScheduleFunc OnSchedule);
void RequestRemoval();
void WaitForRemoval();
TritonModelInstance* triton_model_instance_;
size_t index_;
ModelContext* model_context_;
RateLimiterConfig rate_limiter_config_;
StandardStageFunc OnStage_;
StandardReleaseFunc OnRelease_;
std::atomic<uint64_t> exec_count_;
State state_;
bool removal_in_progress_;
std::mutex state_mtx_;
StandardScheduleFunc OnSchedule_;
std::condition_variable cv_;
};
class ScaledPriorityComparator {
public:
bool operator()(ModelInstanceContext* a, ModelInstanceContext* b)
{
return a->ScaledPriority() > b->ScaledPriority();
}
};
using PriorityQueue = std::priority_queue<
ModelInstanceContext*, std::vector<ModelInstanceContext*>,
ScaledPriorityComparator>;
// Holds the active context to a model
class ModelContext {
public:
ModelContext();
Status EnqueueModelInstanceRequest(
const StandardScheduleFunc& OnSchedule,
TritonModelInstance* triton_model_instance);
void AddAvailableInstance(ModelInstanceContext* instance);
void StageInstanceIfAvailable(TritonModelInstance* triton_model_instance);
void AllocateInstanceIfAvailable();
void AddSpecificRequestQueue();
bool ContainsPendingRequests(int32_t index);
void RequestRemoval();
bool isRemovalInProgress() { return removal_in_progress_; }
private:
bool removal_in_progress_;
// Queues holding pending scheduling requests
std::queue<StandardScheduleFunc> generic_sched_request_queue_;
std::vector<std::queue<StandardScheduleFunc>>
specific_sched_request_queues_;
std::recursive_mutex sched_request_queue_mtx_;
// The set of instances that are available at the moment
PriorityQueue avbl_instances_;
std::recursive_mutex avbl_instances_mtx_;
};
// Manages and keeps track of resource allocation to the model instances.
class ResourceManager {
public:
static Status Create(
const ResourceMap& resource_map,
std::unique_ptr<ResourceManager>* resource_manager);
void AddModelInstance(const ModelInstanceContext* instance);
Status RemoveModelInstance(const ModelInstanceContext* instance);
Status UpdateResourceLimits();
bool AllocateResources(const ModelInstanceContext* instance);
Status ReleaseResources(const ModelInstanceContext* instance);
private:
ResourceManager(const ResourceMap& resource_map);
Status ValidateMaxResources();
Status ParseAndValidateExplicitResources();
ResourceMap explicit_max_resources_;
std::map<const ModelInstanceContext*, ResourceMap> model_resources_;
std::mutex model_resources_mtx_;
ResourceMap max_resources_;
std::mutex max_resources_mtx_;
ResourceMap allocated_resources_;
std::mutex allocated_resources_mtx_;
};
RateLimiter(
const bool ignore_resources_and_priority,
const ResourceMap& resource_map);
void InitializePayloadQueues(const TritonModelInstance* instance);
Status DeferPayloadSchedule(
const StandardScheduleFunc& OnSchedule, const TritonModel* model,
TritonModelInstance* instance = nullptr);
void OnStage(ModelInstanceContext* instance_ptr);
void OnRelease(ModelInstanceContext* instance_ptr);
void AttemptAllocation();
void SchedulePayload(
TritonModelInstance* tmi, PayloadQueue* payload_queue,
const std::shared_ptr<Payload>& payload);
bool ignore_resources_and_priority_;
// Instance context for the models
std::map<
const TritonModel*, std::vector<std::shared_ptr<ModelInstanceContext>>>
model_instance_ctxs_;
std::mutex model_instance_ctx_mtx_;
// Running context of the models
std::map<const TritonModel*, ModelContext> model_contexts_;
std::mutex model_ctx_mtx_;
// Holds the model instances that have been staged
PriorityQueue staged_instances_;
std::recursive_mutex staged_instances_mtx_;
// Manager to keep track of the resource allocations
std::unique_ptr<ResourceManager> resource_manager_;
// Mutex to serialize Payload [de]allocation
std::mutex payload_mu_;
// Mutex to serialize Payload Queues deallocation
std::mutex payload_queues_mu_;
// Keep some number of Payload objects for reuse to avoid the overhead
// of creating a Payload for every new request.
const size_t max_payload_bucket_count_;
std::vector<std::shared_ptr<Payload>> payload_bucket_;
std::deque<std::shared_ptr<Payload>> payloads_in_use_;
struct PayloadQueue {
explicit PayloadQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
{
queue_.reset(new InstanceQueue(max_batch_size, max_queue_delay_ns));
}
std::unique_ptr<InstanceQueue> queue_;
std::map<const TritonModelInstance*, std::unique_ptr<InstanceQueue>>
specific_queues_;
std::mutex mu_;
std::condition_variable cv_;
};
std::map<const TritonModel*, std::unique_ptr<PayloadQueue>> payload_queues_;
};
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "repo_agent.h"
#include <string>
#include "filesystem.h"
#include "shared_library.h"
#include "triton/common/logging.h"
#include "tritonserver_apis.h"
// For an unknown reason, Windows will not export the TRITONREPOAGENT_*
// functions declared with dllexport in tritonrepoagent.h. To get those
// functions exported it is (also?) necessary to mark the definitions in
// this file with dllexport as well.
#if defined(_MSC_VER)
#define TRITONAPI_DECLSPEC __declspec(dllexport)
#elif defined(__GNUC__)
#define TRITONAPI_DECLSPEC __attribute__((__visibility__("default")))
#else
#define TRITONAPI_DECLSPEC
#endif
namespace triton { namespace core {
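// Returns the platform-specific shared-library file name for a repo agent;
// for example, an agent named "checksum" maps to
// libtritonrepoagent_checksum.so on Linux and tritonrepoagent_checksum.dll
// on Windows.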
std::string
TritonRepoAgentLibraryName(const std::string& agent_name)
{
#ifdef _WIN32
return std::string("tritonrepoagent_") + agent_name + ".dll";
#else
return std::string("libtritonrepoagent_") + agent_name + ".so";
#endif
}
std::string
TRITONREPOAGENT_ActionTypeString(const TRITONREPOAGENT_ActionType type)
{
switch (type) {
case TRITONREPOAGENT_ACTION_LOAD:
return "TRITONREPOAGENT_ACTION_LOAD";
case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
return "TRITONREPOAGENT_ACTION_LOAD_COMPLETE";
case TRITONREPOAGENT_ACTION_LOAD_FAIL:
return "TRITONREPOAGENT_ACTION_LOAD_FAIL";
case TRITONREPOAGENT_ACTION_UNLOAD:
return "TRITONREPOAGENT_ACTION_UNLOAD";
case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
return "TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE";
}
return "Unknown TRITONREPOAGENT_ActionType";
}
std::string
TRITONREPOAGENT_ArtifactTypeString(const TRITONREPOAGENT_ArtifactType type)
{
switch (type) {
case TRITONREPOAGENT_ARTIFACT_FILESYSTEM:
return "TRITONREPOAGENT_ARTIFACT_FILESYSTEM";
case TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM:
return "TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM";
}
return "Unknown TRITONREPOAGENT_ArtifactType";
}
//
// TritonRepoAgent
//
Status
TritonRepoAgent::Create(
const std::string& name, const std::string& libpath,
std::shared_ptr<TritonRepoAgent>* agent)
{
std::shared_ptr<TritonRepoAgent> lagent(new TritonRepoAgent(name));
{
std::unique_ptr<SharedLibrary> slib;
RETURN_IF_ERROR(SharedLibrary::Acquire(&slib));
RETURN_IF_ERROR(slib->OpenLibraryHandle(libpath, &lagent->dlhandle_));
RETURN_IF_ERROR(slib->GetEntrypoint(
lagent->dlhandle_, "TRITONREPOAGENT_Initialize", true /* optional */,
reinterpret_cast<void**>(&lagent->init_fn_)));
RETURN_IF_ERROR(slib->GetEntrypoint(
lagent->dlhandle_, "TRITONREPOAGENT_Finalize", true /* optional */,
reinterpret_cast<void**>(&lagent->fini_fn_)));
RETURN_IF_ERROR(slib->GetEntrypoint(
lagent->dlhandle_, "TRITONREPOAGENT_ModelInitialize",
true /* optional */,
reinterpret_cast<void**>(&lagent->model_init_fn_)));
RETURN_IF_ERROR(slib->GetEntrypoint(
lagent->dlhandle_, "TRITONREPOAGENT_ModelFinalize", true /* optional */,
reinterpret_cast<void**>(&lagent->model_fini_fn_)));
RETURN_IF_ERROR(slib->GetEntrypoint(
lagent->dlhandle_, "TRITONREPOAGENT_ModelAction", false /* optional */,
reinterpret_cast<void**>(&lagent->model_action_fn_)));
}
// Initialize if needed
if (lagent->init_fn_ != nullptr) {
RETURN_IF_TRITONSERVER_ERROR(lagent->init_fn_(
reinterpret_cast<TRITONREPOAGENT_Agent*>(lagent.get())));
}
*agent = std::move(lagent);
return Status::Success;
}
TritonRepoAgent::~TritonRepoAgent()
{
// Finalize if needed
if (fini_fn_ != nullptr) {
auto err = fini_fn_(reinterpret_cast<TRITONREPOAGENT_Agent*>(this));
if (err != nullptr) {
LOG_ERROR << "~TritonRepoAgent: "
<< Status(
TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)),
TRITONSERVER_ErrorMessage(err))
.AsString();
TRITONSERVER_ErrorDelete(err);
    }
}
{
std::unique_ptr<SharedLibrary> slib;
LOG_STATUS_ERROR(SharedLibrary::Acquire(&slib), "~TritonRepoAgent");
LOG_STATUS_ERROR(slib->CloseLibraryHandle(dlhandle_), "~TritonRepoAgent");
}
}
//
// TritonRepoAgentModel
//
Status
TritonRepoAgentModel::Create(
const TRITONREPOAGENT_ArtifactType type, const std::string& location,
const inference::ModelConfig& config,
const std::shared_ptr<TritonRepoAgent>& agent,
const TritonRepoAgent::Parameters& agent_parameters,
std::unique_ptr<TritonRepoAgentModel>* agent_model)
{
std::unique_ptr<TritonRepoAgentModel> lagent_model(new TritonRepoAgentModel(
type, location, config, agent, agent_parameters));
if (agent->AgentModelInitFn() != nullptr) {
RETURN_IF_TRITONSERVER_ERROR(agent->AgentModelInitFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(lagent_model.get())));
}
*agent_model = std::move(lagent_model);
return Status::Success;
}
TritonRepoAgentModel::~TritonRepoAgentModel()
{
  // Ensure the agent is informed of the remaining lifecycle actions before
  // this model object is destructed
if (action_type_set_) {
switch (current_action_type_) {
case TRITONREPOAGENT_ACTION_LOAD:
LOG_TRITONSERVER_ERROR(
agent_->AgentModelActionFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
TRITONREPOAGENT_ACTION_LOAD_FAIL),
"Inform TRITONREPOAGENT_ACTION_LOAD_FAIL");
break;
case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
LOG_TRITONSERVER_ERROR(
agent_->AgentModelActionFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
TRITONREPOAGENT_ACTION_UNLOAD),
"Inform TRITONREPOAGENT_ACTION_UNLOAD");
        // The [[fallthrough]] attribute requires C++17, so the
        // UNLOAD_COMPLETE action is invoked explicitly here rather than
        // falling through to the UNLOAD case.
LOG_TRITONSERVER_ERROR(
agent_->AgentModelActionFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
"Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
break;
case TRITONREPOAGENT_ACTION_UNLOAD:
LOG_TRITONSERVER_ERROR(
agent_->AgentModelActionFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this),
TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE),
"Inform TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE");
break;
case TRITONREPOAGENT_ACTION_LOAD_FAIL:
case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
break;
}
}
if (agent_->AgentModelFiniFn() != nullptr) {
LOG_TRITONSERVER_ERROR(
agent_->AgentModelFiniFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this)),
"~TritonRepoAgentModel");
}
if (!acquired_location_.empty()) {
DeleteMutableLocation();
}
}
Status
TritonRepoAgentModel::InvokeAgent(const TRITONREPOAGENT_ActionType action_type)
{
if ((!action_type_set_) && (action_type != TRITONREPOAGENT_ACTION_LOAD)) {
return Status(
Status::Code::INTERNAL,
"Unexpected lifecycle start state " +
TRITONREPOAGENT_ActionTypeString(action_type));
}
switch (action_type) {
case TRITONREPOAGENT_ACTION_LOAD:
if (action_type_set_) {
return Status(
Status::Code::INTERNAL,
"Unexpected lifecycle state transition from " +
TRITONREPOAGENT_ActionTypeString(current_action_type_) +
" to " + TRITONREPOAGENT_ActionTypeString(action_type));
}
break;
case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
case TRITONREPOAGENT_ACTION_LOAD_FAIL:
if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
return Status(
Status::Code::INTERNAL,
"Unexpected lifecycle state transition from " +
TRITONREPOAGENT_ActionTypeString(current_action_type_) +
" to " + TRITONREPOAGENT_ActionTypeString(action_type));
}
break;
case TRITONREPOAGENT_ACTION_UNLOAD:
if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD_COMPLETE) {
return Status(
Status::Code::INTERNAL,
"Unexpected lifecycle state transition from " +
TRITONREPOAGENT_ActionTypeString(current_action_type_) +
" to " + TRITONREPOAGENT_ActionTypeString(action_type));
}
break;
case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE:
if (current_action_type_ != TRITONREPOAGENT_ACTION_UNLOAD) {
return Status(
Status::Code::INTERNAL,
"Unexpected lifecycle state transition from " +
TRITONREPOAGENT_ActionTypeString(current_action_type_) +
" to " + TRITONREPOAGENT_ActionTypeString(action_type));
}
break;
}
current_action_type_ = action_type;
action_type_set_ = true;
RETURN_IF_TRITONSERVER_ERROR(agent_->AgentModelActionFn()(
reinterpret_cast<TRITONREPOAGENT_Agent*>(agent_.get()),
reinterpret_cast<TRITONREPOAGENT_AgentModel*>(this), action_type));
return Status::Success;
}
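// Illustrative sketch (hypothetical caller, assumes a successful load): the
// transitions accepted above form the sequence below, with
// TRITONREPOAGENT_ACTION_LOAD_FAIL as the alternative to LOAD_COMPLETE.
//
//   RETURN_IF_ERROR(model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
//   RETURN_IF_ERROR(model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD_COMPLETE));
//   // ... model is in service ...
//   RETURN_IF_ERROR(model->InvokeAgent(TRITONREPOAGENT_ACTION_UNLOAD));
//   RETURN_IF_ERROR(model->InvokeAgent(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE));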
Status
TritonRepoAgentModel::SetLocation(
const TRITONREPOAGENT_ArtifactType type, const std::string& location)
{
if (current_action_type_ != TRITONREPOAGENT_ACTION_LOAD) {
return Status(
Status::Code::INVALID_ARG,
"location can only be updated during TRITONREPOAGENT_ACTION_LOAD, "
"current action type is " +
(action_type_set_
? TRITONREPOAGENT_ActionTypeString(current_action_type_)
: "not set"));
}
type_ = type;
location_ = location;
return Status::Success;
}
Status
TritonRepoAgentModel::Location(
TRITONREPOAGENT_ArtifactType* type, const char** location)
{
if (location_.empty()) {
return Status(
Status::Code::INTERNAL, "Model repository location is not set");
}
*type = type_;
*location = location_.c_str();
return Status::Success;
}
Status
TritonRepoAgentModel::AcquireMutableLocation(
const TRITONREPOAGENT_ArtifactType type, const char** location)
{
if (type != TRITONREPOAGENT_ARTIFACT_FILESYSTEM) {
return Status(
Status::Code::INVALID_ARG,
"Unexpected artifact type, expects "
"'TRITONREPOAGENT_ARTIFACT_FILESYSTEM'");
}
if (acquired_location_.empty()) {
std::string lacquired_location;
RETURN_IF_ERROR(
MakeTemporaryDirectory(FileSystemType::LOCAL, &lacquired_location));
acquired_location_.swap(lacquired_location);
acquired_type_ = type;
}
*location = acquired_location_.c_str();
return Status::Success;
}
Status
TritonRepoAgentModel::DeleteMutableLocation()
{
if (acquired_location_.empty()) {
return Status(
Status::Code::UNAVAILABLE, "No mutable location to be deleted");
}
auto status = DeletePath(acquired_location_);
if (!status.IsOk()) {
LOG_ERROR << "Failed to delete previously acquired location '"
<< acquired_location_ << "': " << status.AsString();
}
acquired_location_.clear();
return Status::Success;
}
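// Illustrative sketch (hypothetical caller, error handling via
// RETURN_IF_ERROR): an agent that rewrites model artifacts would typically
// acquire a mutable location, populate it, and point the model at it during
// TRITONREPOAGENT_ACTION_LOAD. The acquired directory is released here or by
// the destructor.
//
//   const char* scratch = nullptr;
//   RETURN_IF_ERROR(model->AcquireMutableLocation(
//       TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &scratch));
//   // ... write the transformed artifacts under 'scratch' ...
//   RETURN_IF_ERROR(
//       model->SetLocation(TRITONREPOAGENT_ARTIFACT_FILESYSTEM, scratch));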
//
// TritonRepoAgentManager
//
TritonRepoAgentManager&
TritonRepoAgentManager::Singleton()
{
static TritonRepoAgentManager triton_repo_agent_manager;
return triton_repo_agent_manager;
}
Status
TritonRepoAgentManager::SetGlobalSearchPath(const std::string& path)
{
auto& singleton_manager = Singleton();
std::lock_guard<std::mutex> lock(singleton_manager.mu_);
singleton_manager.global_search_path_ = path;
return Status::Success;
}
Status
TritonRepoAgentManager::CreateAgent(
const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent)
{
auto& singleton_manager = Singleton();
std::lock_guard<std::mutex> lock(singleton_manager.mu_);
// Get the path to the agent shared library. Search path is global
// agent directory. FIXME expose global path as Triton option
const std::vector<std::string> search_paths = {
JoinPath({singleton_manager.global_search_path_, agent_name})};
std::string agent_libname = TritonRepoAgentLibraryName(agent_name);
std::string libpath;
for (const auto& path : search_paths) {
const auto full_path = JoinPath({path, agent_libname});
bool exists = false;
RETURN_IF_ERROR(FileExists(full_path, &exists));
if (exists) {
libpath = full_path;
break;
}
}
if (libpath.empty()) {
return Status(
Status::Code::INVALID_ARG,
"unable to find '" + agent_libname + "' for repo agent '" + agent_name +
"', searched: " + singleton_manager.global_search_path_);
}
const auto& itr = singleton_manager.agent_map_.find(libpath);
if (itr != singleton_manager.agent_map_.end()) {
// Found in map. If the weak_ptr is still valid that means that
// there are other models using the agent and we just reuse that
// same agent. If the weak_ptr is not valid then agent has been
// unloaded so we need to remove the weak_ptr from the map and
// create the agent again.
*agent = itr->second.lock();
if (*agent != nullptr) {
return Status::Success;
}
singleton_manager.agent_map_.erase(itr);
}
RETURN_IF_ERROR(TritonRepoAgent::Create(agent_name, libpath, agent));
singleton_manager.agent_map_.insert({libpath, *agent});
return Status::Success;
}
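// Illustrative sketch (hypothetical agent name): agents are cached by library
// path above, so repeated calls for the same agent reuse one instance while
// any model still holds the shared_ptr.
//
//   std::shared_ptr<TritonRepoAgent> agent;
//   RETURN_IF_ERROR(TritonRepoAgentManager::CreateAgent("checksum", &agent));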
Status
TritonRepoAgentManager::AgentState(
std::unique_ptr<std::unordered_map<std::string, std::string>>* agent_state)
{
auto& singleton_manager = Singleton();
std::lock_guard<std::mutex> lock(singleton_manager.mu_);
std::unique_ptr<std::unordered_map<std::string, std::string>> agent_state_map(
new std::unordered_map<std::string, std::string>);
for (const auto& agent_pair : singleton_manager.agent_map_) {
auto& libpath = agent_pair.first;
auto agent = agent_pair.second.lock();
if (agent != nullptr) {
agent_state_map->insert({agent->Name(), libpath});
}
}
*agent_state = std::move(agent_state_map);
return Status::Success;
}
extern "C" {
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ApiVersion(uint32_t* major, uint32_t* minor)
{
*major = TRITONREPOAGENT_API_VERSION_MAJOR;
*minor = TRITONREPOAGENT_API_VERSION_MINOR;
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocation(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
TRITONREPOAGENT_ArtifactType* artifact_type, const char** location)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->Location(artifact_type, location));
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationAcquire(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const TRITONREPOAGENT_ArtifactType artifact_type, const char** location)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(
tam->AcquireMutableLocation(artifact_type, location));
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryLocationRelease(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const char* location)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->DeleteMutableLocation());
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelRepositoryUpdate(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const TRITONREPOAGENT_ArtifactType artifact_type, const char* location)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(tam->SetLocation(artifact_type, location));
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameterCount(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
uint32_t* count)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
*count = tam->AgentParameters().size();
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelParameter(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const uint32_t index, const char** parameter_name,
const char** parameter_value)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
const auto& params = tam->AgentParameters();
if (index >= params.size()) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"index out of range for model parameters");
}
*parameter_name = params[index].first.c_str();
*parameter_value = params[index].second.c_str();
return nullptr; // success
}
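// Illustrative sketch (from a repo agent's point of view, error handling
// elided): the model parameters exposed above can be enumerated by combining
// the two calls.
//
//   uint32_t count = 0;
//   TRITONREPOAGENT_ModelParameterCount(agent, model, &count);
//   for (uint32_t i = 0; i < count; ++i) {
//     const char* name = nullptr;
//     const char* value = nullptr;
//     TRITONREPOAGENT_ModelParameter(agent, model, i, &name, &value);
//   }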
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelConfig(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const uint32_t config_version, TRITONSERVER_Message** model_config)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
std::string model_config_json;
RETURN_TRITONSERVER_ERROR_IF_ERROR(
ModelConfigToJson(tam->Config(), config_version, &model_config_json));
return TRITONSERVER_MessageNewFromSerializedJson(
model_config, model_config_json.c_str(), model_config_json.length());
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelState(TRITONREPOAGENT_AgentModel* model, void** state)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
*state = tam->State();
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_ModelSetState(TRITONREPOAGENT_AgentModel* model, void* state)
{
TritonRepoAgentModel* tam = reinterpret_cast<TritonRepoAgentModel*>(model);
tam->SetState(state);
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_State(TRITONREPOAGENT_Agent* agent, void** state)
{
TritonRepoAgent* ta = reinterpret_cast<TritonRepoAgent*>(agent);
*state = ta->State();
return nullptr; // success
}
TRITONAPI_DECLSPEC TRITONSERVER_Error*
TRITONREPOAGENT_SetState(TRITONREPOAGENT_Agent* agent, void* state)
{
TritonRepoAgent* ta = reinterpret_cast<TritonRepoAgent*>(agent);
ta->SetState(state);
return nullptr; // success
}
} // extern C
}} // namespace triton::core
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
#include <memory>
#include <mutex>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "model_config_utils.h"
namespace triton { namespace core {
std::string TritonRepoAgentLibraryName(const std::string& agent_name);
std::string TRITONREPOAGENT_ActionTypeString(
const TRITONREPOAGENT_ActionType type);
std::string TRITONREPOAGENT_ArtifactTypeString(
const TRITONREPOAGENT_ArtifactType type);
class TritonRepoAgent {
public:
using Parameters = std::vector<std::pair<std::string, std::string>>;
typedef TRITONSERVER_Error* (*TritonRepoAgentInitFn_t)(
TRITONREPOAGENT_Agent* agent);
typedef TRITONSERVER_Error* (*TritonRepoAgentFiniFn_t)(
TRITONREPOAGENT_Agent* agent);
typedef TRITONSERVER_Error* (*TritonRepoAgentModelInitFn_t)(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
typedef TRITONSERVER_Error* (*TritonRepoAgentModelFiniFn_t)(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model);
typedef TRITONSERVER_Error* (*TritonRepoAgentModelActionFn_t)(
TRITONREPOAGENT_Agent* agent, TRITONREPOAGENT_AgentModel* model,
const TRITONREPOAGENT_ActionType action_type);
static Status Create(
const std::string& name, const std::string& libpath,
std::shared_ptr<TritonRepoAgent>* agent);
~TritonRepoAgent();
const std::string& Name() { return name_; }
void* State() { return state_; }
void SetState(void* state) { state_ = state; }
TritonRepoAgentModelActionFn_t AgentModelActionFn() const
{
return model_action_fn_;
}
TritonRepoAgentModelInitFn_t AgentModelInitFn() const
{
return model_init_fn_;
}
TritonRepoAgentModelFiniFn_t AgentModelFiniFn() const
{
return model_fini_fn_;
}
protected:
DISALLOW_COPY_AND_ASSIGN(TritonRepoAgent);
TritonRepoAgent(const std::string& name)
: name_(name), state_(nullptr), dlhandle_(nullptr), init_fn_(nullptr),
fini_fn_(nullptr), model_init_fn_(nullptr), model_fini_fn_(nullptr),
model_action_fn_(nullptr)
{
}
const std::string name_;
void* state_;
// dlopen / dlsym handles
void* dlhandle_;
TritonRepoAgentInitFn_t init_fn_;
TritonRepoAgentFiniFn_t fini_fn_;
TritonRepoAgentModelInitFn_t model_init_fn_;
TritonRepoAgentModelFiniFn_t model_fini_fn_;
TritonRepoAgentModelActionFn_t model_action_fn_;
};
class TritonRepoAgentModel {
public:
static Status Create(
const TRITONREPOAGENT_ArtifactType type, const std::string& location,
const inference::ModelConfig& config,
const std::shared_ptr<TritonRepoAgent>& agent,
const TritonRepoAgent::Parameters& agent_parameters,
std::unique_ptr<TritonRepoAgentModel>* agent_model);
~TritonRepoAgentModel();
void* State() { return state_; }
void SetState(void* state) { state_ = state; }
Status InvokeAgent(const TRITONREPOAGENT_ActionType action_type);
const TritonRepoAgent::Parameters& AgentParameters()
{
return agent_parameters_;
}
Status SetLocation(
const TRITONREPOAGENT_ArtifactType type, const std::string& location);
Status Location(TRITONREPOAGENT_ArtifactType* type, const char** location);
Status AcquireMutableLocation(
const TRITONREPOAGENT_ArtifactType type, const char** location);
Status DeleteMutableLocation();
const inference::ModelConfig Config() { return config_; }
private:
DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModel);
TritonRepoAgentModel(
const TRITONREPOAGENT_ArtifactType type, const std::string& location,
const inference::ModelConfig& config,
const std::shared_ptr<TritonRepoAgent>& agent,
const TritonRepoAgent::Parameters& agent_parameters)
: state_(nullptr), config_(config), agent_(agent),
agent_parameters_(agent_parameters), type_(type), location_(location),
action_type_set_(false),
current_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE)
{
}
void* state_;
const inference::ModelConfig config_;
const std::shared_ptr<TritonRepoAgent> agent_;
const TritonRepoAgent::Parameters agent_parameters_;
TRITONREPOAGENT_ArtifactType type_;
std::string location_;
TRITONREPOAGENT_ArtifactType acquired_type_;
std::string acquired_location_;
bool action_type_set_;
TRITONREPOAGENT_ActionType current_action_type_;
};
class TritonRepoAgentManager {
public:
static Status SetGlobalSearchPath(const std::string& path);
static Status CreateAgent(
const std::string& agent_name, std::shared_ptr<TritonRepoAgent>* agent);
static Status AgentState(
std::unique_ptr<std::unordered_map<std::string, std::string>>*
agent_state);
private:
DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentManager);
TritonRepoAgentManager()
      : global_search_path_("/opt/tritonserver/repoagents") {}
static TritonRepoAgentManager& Singleton();
std::mutex mu_;
std::string global_search_path_;
std::unordered_map<std::string, std::weak_ptr<TritonRepoAgent>> agent_map_;
};
}} // namespace triton::core
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "tritonserver_apis.h"
namespace triton { namespace core {
//
// Implementation for TRITONSERVER_ResponseAllocator.
//
class ResponseAllocator {
public:
explicit ResponseAllocator(
TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn,
TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn,
TRITONSERVER_ResponseAllocatorStartFn_t start_fn)
: alloc_fn_(alloc_fn), buffer_attributes_fn_(nullptr), query_fn_(nullptr),
release_fn_(release_fn), start_fn_(start_fn)
{
}
void SetQueryFunction(TRITONSERVER_ResponseAllocatorQueryFn_t query_fn)
{
query_fn_ = query_fn;
}
void SetBufferAttributesFunction(
TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn)
{
buffer_attributes_fn_ = buffer_attributes_fn;
}
TRITONSERVER_ResponseAllocatorAllocFn_t AllocFn() const { return alloc_fn_; }
TRITONSERVER_ResponseAllocatorBufferAttributesFn_t BufferAttributesFn() const
{
return buffer_attributes_fn_;
}
TRITONSERVER_ResponseAllocatorQueryFn_t QueryFn() const { return query_fn_; }
TRITONSERVER_ResponseAllocatorReleaseFn_t ReleaseFn() const
{
return release_fn_;
}
TRITONSERVER_ResponseAllocatorStartFn_t StartFn() const { return start_fn_; }
private:
TRITONSERVER_ResponseAllocatorAllocFn_t alloc_fn_;
TRITONSERVER_ResponseAllocatorBufferAttributesFn_t buffer_attributes_fn_;
TRITONSERVER_ResponseAllocatorQueryFn_t query_fn_;
TRITONSERVER_ResponseAllocatorReleaseFn_t release_fn_;
TRITONSERVER_ResponseAllocatorStartFn_t start_fn_;
};
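// Illustrative sketch (hypothetical callback functions): the allocator is
// constructed with the three required callbacks; the query and
// buffer-attributes callbacks are optional and attached afterwards.
//
//   ResponseAllocator allocator(MyAllocFn, MyReleaseFn, MyStartFn);
//   allocator.SetQueryFunction(MyQueryFn);
//   allocator.SetBufferAttributesFunction(MyBufferAttributesFn);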
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "response_cache.h"
#include "infer_stats.h"
#include "triton/common/logging.h"
namespace {
enum class ScopedTimerType { INSERTION, LOOKUP };
class ScopedTimer {
public:
explicit ScopedTimer(
triton::core::InferenceRequest& request, uint64_t& duration,
ScopedTimerType type)
: request_(request), duration_(duration), type_(type)
{
switch (type_) {
case ScopedTimerType::LOOKUP:
request_.CaptureCacheLookupStartNs();
break;
case ScopedTimerType::INSERTION:
request_.CaptureCacheInsertionStartNs();
break;
}
}
~ScopedTimer()
{
switch (type_) {
case ScopedTimerType::LOOKUP:
request_.CaptureCacheLookupEndNs();
duration_ +=
request_.CacheLookupEndNs() - request_.CacheLookupStartNs();
break;
case ScopedTimerType::INSERTION:
request_.CaptureCacheInsertionEndNs();
duration_ +=
request_.CacheInsertionEndNs() - request_.CacheInsertionStartNs();
break;
}
}
private:
triton::core::InferenceRequest& request_;
uint64_t& duration_;
ScopedTimerType type_;
};
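// Illustrative sketch: ScopedTimer is used as an RAII guard in
// Lookup()/Insert() below; the elapsed time is accumulated into the given
// duration when the timer goes out of scope.
//
//   uint64_t total_ns = 0;
//   {
//     ScopedTimer timer(*request, total_ns, ScopedTimerType::LOOKUP);
//     // ... perform the lookup ...
//   }  // total_ns now includes this lookup's latency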
std::string
PointerToString(void* ptr)
{
std::stringstream ss;
ss << ptr;
return ss.str();
}
} // namespace
namespace triton { namespace core {
Status
RequestResponseCache::Create(
uint64_t cache_size, std::unique_ptr<RequestResponseCache>* cache)
{
try {
cache->reset(new RequestResponseCache(cache_size));
}
catch (const std::exception& ex) {
return Status(
Status::Code::INTERNAL,
"Failed to initialize Response Cache: " + std::string(ex.what()));
}
return Status::Success;
}
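// Illustrative sketch (hypothetical cache size): the cache is constructed
// with a fixed byte budget, and all entries are carved out of that single
// allocation.
//
//   std::unique_ptr<RequestResponseCache> cache;
//   RETURN_IF_ERROR(RequestResponseCache::Create(64 * 1024 * 1024, &cache));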
RequestResponseCache::RequestResponseCache(const uint64_t size)
{
// Allocate buffer
buffer_ = malloc(size);
// Exit early if buffer allocation failed
if (buffer_ == nullptr) {
throw std::runtime_error("failed to allocate buffer");
}
// Create cache as managed buffer
managed_buffer_ = boost::interprocess::managed_external_buffer(
boost::interprocess::create_only_t{}, buffer_, size);
LOG_INFO << "Response Cache is created at '" << PointerToString(buffer_)
<< "' with size " << size;
}
RequestResponseCache::~RequestResponseCache()
{
// Deallocate each chunk from managed buffer
for (auto& iter : cache_) {
auto& entry = iter.second;
for (auto& output : entry.outputs_) {
if (output.buffer_ != nullptr) {
managed_buffer_.deallocate(output.buffer_);
}
}
}
// Validate we freed all underlying memory managed by cache
if (!managed_buffer_.all_memory_deallocated()) {
// Destructors can't throw exceptions
LOG_ERROR << "failed to free managed cache memory";
}
// Free total cache buffer
if (buffer_ != nullptr) {
free(buffer_);
}
}
Status
RequestResponseCache::Lookup(
InferenceResponse* const response, InferenceRequest* const request)
{
// Lock on cache lookup
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
if (request == nullptr) {
return Status(
Status::Code::INTERNAL, "Cache Lookup passed a nullptr request");
}
// Capture start latency now and end latency when timer goes out of scope
ScopedTimer timer(
*request, total_lookup_latency_ns_, ScopedTimerType::LOOKUP);
// Hash the request and set cache key if it hasn't already been set
if (!request->CacheKeyIsSet()) {
RETURN_IF_ERROR(HashAndSet(request));
}
const uint64_t key = request->CacheKey();
num_lookups_++;
LOG_VERBOSE(1) << request->LogRequest()
<< "Looking up key [" + std::to_string(key) + "] in cache.";
// Search cache for request hash key
auto iter = cache_.find(key);
if (iter == cache_.end()) {
num_misses_++;
LOG_VERBOSE(1) << request->LogRequest()
<< "MISS for key [" + std::to_string(key) + "] in cache.";
return Status(
Status::Code::INTERNAL,
request->LogRequest() + "key not found in cache");
}
// If find succeeds, it's a cache hit
num_hits_++;
LOG_VERBOSE(1) << request->LogRequest()
<< "HIT for key [" + std::to_string(key) + "] in cache.";
// Populate passed-in "response" from cache entry
auto entry = iter->second;
// Build InferenceResponse from CacheEntry
RETURN_IF_ERROR(BuildInferenceResponse(entry, response));
// Update this key to front of LRU list
UpdateLRU(iter);
LOG_VERBOSE(1) << request->LogRequest()
<< "Using cached response for key [" + std::to_string(key) +
"].";
return Status::Success;
}
Status
RequestResponseCache::Insert(
const InferenceResponse& response, InferenceRequest* const request)
{
// Lock on cache insertion
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
if (request == nullptr) {
return Status(
Status::Code::INTERNAL, "Cache Insert passed a nullptr request");
}
// Capture start latency now and end latency when timer goes out of scope
ScopedTimer timer(
*request, total_insertion_latency_ns_, ScopedTimerType::INSERTION);
// Hash the request and set cache key if it hasn't already been set
if (!request->CacheKeyIsSet()) {
RETURN_IF_ERROR(HashAndSet(request));
}
const uint64_t key = request->CacheKey();
// Exit early if key already exists in cache
auto iter = cache_.find(key);
if (iter != cache_.end()) {
return Status(
Status::Code::ALREADY_EXISTS, request->LogRequest() + "key [" +
std::to_string(key) +
"] already exists in cache");
}
// Construct cache entry from response
auto entry = CacheEntry();
RETURN_IF_ERROR(BuildCacheEntry(response, &entry));
// Insert entry into cache
LOG_VERBOSE(1) << request->LogRequest()
<< "Inserting key [" + std::to_string(key) + "] into cache.";
auto cache_pair = cache_.insert({key, entry});
// Exit early if cache insertion failed
if (!cache_pair.second) {
LOG_ERROR << request->LogRequest() << "Failed to insert key into map.";
return Status(
Status::Code::INTERNAL,
request->LogRequest() + "Cache insertion failed");
}
// Update LRU with new cache entry
auto cache_iter = cache_pair.first;
UpdateLRU(cache_iter);
return Status::Success;
}
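// Illustrative sketch (hypothetical caller, error handling elided): a typical
// caller consults the cache before running inference and inserts the computed
// response afterwards.
//
//   if (!cache->Lookup(response.get(), request.get()).IsOk()) {
//     // Cache miss: run inference to fill 'response', then cache it
//     cache->Insert(*response, request.get());
//   }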
// LRU
Status
RequestResponseCache::Evict()
{
// Lock on cache eviction
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
// Nothing to evict if cache is empty
if (NumEntries() == 0) {
return Status(Status::Code::INTERNAL, "Cache is empty, nothing to evict.");
}
// Least recently used key in back of LRU list
uint64_t lru_key = lru_.back();
LOG_VERBOSE(1) << "Evicting key [" + std::to_string(lru_key) +
"] from cache.";
// Find cache entry for least recently used key
auto iter = cache_.find(lru_key);
  // The key should always be present in the cache during eviction; a missing
  // key here indicates a bug in the code
if (iter == cache_.end()) {
return Status(
Status::Code::INTERNAL,
"key [" + std::to_string(lru_key) +
"] not found in cache during eviction: this indicates a bug in the "
"code");
}
  // Get the cache entry being evicted so its output buffers can be freed
auto entry = iter->second;
// Free managed memory used in cache entry's outputs
for (auto& output : entry.outputs_) {
// Lock on buffer deallocation
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
managed_buffer_.deallocate(output.buffer_);
}
// Remove LRU entry from cache
cache_.erase(lru_key);
// Remove LRU key from LRU list
lru_.pop_back();
// Increment number of evictions
num_evictions_++;
return Status::Success;
}
// Helpers
void
RequestResponseCache::UpdateLRU(
std::unordered_map<uint64_t, CacheEntry>::iterator& cache_iter)
{
// Lock on cache update
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
const auto& key = cache_iter->first;
auto& cache_entry = cache_iter->second;
// Remove key from LRU list if it was already in there
auto lru_iter = std::find(lru_.begin(), lru_.end(), key);
if (lru_iter != lru_.end()) {
lru_.erase(lru_iter);
}
// Add key to front of LRU list since it's most recently used
lru_.push_front(key);
// Set CacheEntry LRU iterator to new LRU key location
cache_entry.lru_iter_ = lru_.begin();
}
Status
RequestResponseCache::BuildCacheEntry(
const InferenceResponse& response, CacheEntry* const entry)
{
// Build cache entry data from response outputs
for (const auto& response_output : response.Outputs()) {
auto cache_output = Output();
// Fetch output buffer details
const void* response_buffer = nullptr;
size_t response_byte_size = 0;
TRITONSERVER_MemoryType response_memory_type;
int64_t response_memory_type_id;
void* userp;
RETURN_IF_ERROR(response_output.DataBuffer(
&response_buffer, &response_byte_size, &response_memory_type,
&response_memory_type_id, &userp));
// TODO: Handle other memory types
if (response_memory_type != TRITONSERVER_MEMORY_CPU &&
response_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) {
return Status(
Status::Code::INTERNAL,
"Only input buffers in CPU memory are allowed in cache currently");
}
// Exit early if response buffer from output is invalid
if (response_buffer == nullptr) {
return Status(
Status::Code::INTERNAL, "Response buffer from output was nullptr");
}
// Lock on managed buffer references
{
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
      // Exit early if the cache entry would be larger than the total cache size
if (response_byte_size > managed_buffer_.get_size()) {
return Status(
Status::Code::INTERNAL,
"Cache entry is larger than total cache size");
}
// If cache doesn't have enough space, evict until enough space available
// NOTE: FreeBytes() doesn't account for allocator overhead so allocation
// may fail even if response_byte_size is less than FreeBytes()
while (response_byte_size > FreeBytes()) {
LOG_VERBOSE(1) << "EVICT: Response larger than remaining available "
"memory, attempting to evict from cache.";
RETURN_IF_ERROR(Evict());
}
// Attempt to allocate buffer until success or eviction from cache fails
while (cache_output.buffer_ == nullptr) {
// Allocate buffer for response output in cache entry
cache_output.buffer_ =
managed_buffer_.allocate(response_byte_size, std::nothrow_t{});
// Attempt to evict if allocation fails
if (cache_output.buffer_ == nullptr) {
LOG_VERBOSE(1) << "FAILED to allocate buffer in cache. Attempting to "
"evict an entry.";
          // Return the error if eviction fails
RETURN_IF_ERROR(Evict());
}
}
// Copy data from response buffer to cache entry output buffer
// TODO: Handle other memory types
std::memcpy(cache_output.buffer_, response_buffer, response_byte_size);
// Set output metadata
cache_output.name_ = response_output.Name();
cache_output.dtype_ = response_output.DType();
cache_output.shape_ = response_output.Shape();
cache_output.buffer_size_ = static_cast<uint64_t>(response_byte_size);
}
// Add each output to cache entry
entry->outputs_.push_back(cache_output);
}
return Status::Success;
}
Status
RequestResponseCache::BuildInferenceResponse(
const CacheEntry& entry, InferenceResponse* const response)
{
if (response == nullptr) {
return Status(Status::Code::INTERNAL, "invalid response ptr passed in");
}
// Lock on cache references
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
// Inference response outputs should be empty so we can append to them
if (response->Outputs().size() != 0) {
return Status(
Status::Code::INTERNAL,
"InferenceResponse already contains some outputs");
}
for (auto& cache_output : entry.outputs_) {
InferenceResponse::Output* response_output = nullptr;
RETURN_IF_ERROR(response->AddOutput(
cache_output.name_, cache_output.dtype_, cache_output.shape_,
&response_output));
if (response_output == nullptr) {
return Status(
Status::Code::INTERNAL,
"InferenceResponse::Output pointer as nullptr");
}
TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
int64_t memory_type_id = 0;
// Allocate buffer for inference response
void* buffer;
RETURN_IF_ERROR(response_output->AllocateDataBuffer(
&buffer, cache_output.buffer_size_, &memory_type, &memory_type_id));
// TODO: Handle other memory types
if (memory_type != TRITONSERVER_MEMORY_CPU &&
memory_type != TRITONSERVER_MEMORY_CPU_PINNED) {
return Status(
Status::Code::INTERNAL,
"Only input buffers in CPU memory are allowed in cache currently");
}
if (buffer == nullptr) {
return Status(
Status::Code::INTERNAL, "failed to allocate buffer for output '" +
cache_output.name_ + "'");
}
// Copy cached output buffer to allocated response output buffer
std::memcpy(buffer, cache_output.buffer_, cache_output.buffer_size_);
// TODO: Add field to InferenceResponse to indicate this was from cache
// response.cached = true;
}
}
return Status::Success;
}
Status
RequestResponseCache::HashInputBuffers(
const InferenceRequest::Input* input, size_t* seed)
{
// Iterate over each data buffer in input in case of non-contiguous memory
for (size_t idx = 0; idx < input->DataBufferCount(); ++idx) {
const void* src_buffer;
size_t src_byte_size;
TRITONSERVER_MemoryType src_memory_type;
int64_t src_memory_type_id;
RETURN_IF_ERROR(input->DataBuffer(
idx, &src_buffer, &src_byte_size, &src_memory_type,
&src_memory_type_id));
// TODO: Handle other memory types
if (src_memory_type != TRITONSERVER_MEMORY_CPU &&
src_memory_type != TRITONSERVER_MEMORY_CPU_PINNED) {
return Status(
Status::Code::INTERNAL,
"Only input buffers in CPU memory are allowed in cache currently");
}
// Add each byte of input buffer chunk to hash
const unsigned char* tmp = static_cast<const unsigned char*>(src_buffer);
for (uint64_t byte = 0; byte < src_byte_size; byte++) {
boost::hash_combine(*seed, tmp[byte]);
}
}
return Status::Success;
}
Status
RequestResponseCache::HashInputs(const InferenceRequest& request, size_t* seed)
{
const auto& inputs = request.ImmutableInputs();
  // Convert inputs to an ordered map, sorted by input name, so the hash is
  // computed in a consistent order
std::map<std::string, InferenceRequest::Input*> ordered_inputs(
inputs.begin(), inputs.end());
for (const auto& input : ordered_inputs) {
// Add input name to hash
boost::hash_combine(*seed, input.second->Name());
// Fetch input buffer for hashing raw data
RETURN_IF_ERROR(HashInputBuffers(input.second, seed));
}
return Status::Success;
}
Status
RequestResponseCache::Hash(const InferenceRequest& request, uint64_t* key)
{
std::size_t seed = 0;
// Add request model name to hash
boost::hash_combine(seed, request.ModelName());
// Add request model version to hash
boost::hash_combine(seed, request.ActualModelVersion());
RETURN_IF_ERROR(HashInputs(request, &seed));
*key = static_cast<uint64_t>(seed);
return Status::Success;
}
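// Illustrative sketch (hypothetical values): the key is a boost::hash_combine
// accumulation over the model name, the resolved model version, and each
// input's name and raw bytes, so requests differing in any of those fields
// hash to different keys with high probability.
//
//   std::size_t seed = 0;
//   boost::hash_combine(seed, std::string("model_a"));
//   boost::hash_combine(seed, static_cast<int64_t>(1));
//   // ... per-input names and data bytes combined as in HashInputs() above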
Status
RequestResponseCache::HashAndSet(InferenceRequest* const request)
{
uint64_t key = 0;
RETURN_IF_ERROR(Hash(*request, &key));
request->SetCacheKey(key);
return Status::Success;
}
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <list>
#include <string>
#include <unordered_map>
#include "infer_request.h"
#include "infer_response.h"
#include "model.h"
#include "status.h"
#include <boost/functional/hash.hpp>
#include <boost/interprocess/managed_external_buffer.hpp>
namespace triton { namespace core {
// Assuming CPU memory only for now
struct Output {
// Output tensor data buffer
void* buffer_;
// Size of "buffer" above
uint64_t buffer_size_ = 0;
// Name of the output
std::string name_;
// Datatype of the output
inference::DataType dtype_;
// Shape of the output
std::vector<int64_t> shape_;
};
struct CacheEntry {
explicit CacheEntry() {}
// Point to key in LRU list for maintaining LRU order
std::list<uint64_t>::iterator lru_iter_;
  // Each output buffer is allocated from the managed buffer, i.e.
  // managed_buffer.allocate(size, ...)
std::vector<Output> outputs_;
};
class RequestResponseCache {
public:
~RequestResponseCache();
// Create the request/response cache object
static Status Create(
uint64_t cache_size, std::unique_ptr<RequestResponseCache>* cache);
  // Hash the inference request for cache access and store the hash in the
  // "request" object. This is also called internally in Lookup/Insert if the
  // request has not already stored its hash. It is up to the caller to update
  // the hash if any hashed fields of the request object are modified after it
  // has been stored.
  // Return Status object indicating success or failure.
Status HashAndSet(InferenceRequest* const request);
  // Look up the 'request' hash in the cache and populate 'response' from the
  // cached entry on a cache hit; a cache miss is reported as a non-success
  // Status.
  // Return Status object indicating success or failure.
Status Lookup(
InferenceResponse* const response, InferenceRequest* const request);
// Insert response into cache, evict entries to make space if necessary
// Return Status object indicating success or failure.
Status Insert(
const InferenceResponse& response, InferenceRequest* const request);
// Evict entry from cache based on policy
// Return Status object indicating success or failure.
Status Evict();
// Returns number of items in cache
size_t NumEntries()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return cache_.size();
}
// Returns number of items evicted in cache lifespan
size_t NumEvictions()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return num_evictions_;
}
  // Returns number of lookups in cache lifespan; this should equal the sum of
  // hits and misses
size_t NumLookups()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return num_lookups_;
}
// Returns number of cache hits in cache lifespan
size_t NumHits()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return num_hits_;
}
  // Returns number of cache misses in cache lifespan
size_t NumMisses()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return num_misses_;
}
// Returns the total lookup latency (nanoseconds) of all lookups in cache
// lifespan
uint64_t TotalLookupLatencyNs()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return total_lookup_latency_ns_;
}
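  // Returns the total insertion latency (nanoseconds) of all insertions in
  // cache lifespan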
uint64_t TotalInsertionLatencyNs()
{
std::lock_guard<std::recursive_mutex> lk(cache_mtx_);
return total_insertion_latency_ns_;
}
// Returns total number of bytes allocated for cache
size_t TotalBytes()
{
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
return managed_buffer_.get_size();
}
// Returns number of free bytes in cache
size_t FreeBytes()
{
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
return managed_buffer_.get_free_memory();
}
// Returns number of bytes in use by cache
size_t AllocatedBytes()
{
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
return managed_buffer_.get_size() - managed_buffer_.get_free_memory();
}
// Returns fraction of bytes allocated over total cache size between [0, 1]
double TotalUtilization()
{
std::lock_guard<std::recursive_mutex> lk(buffer_mtx_);
return static_cast<double>(AllocatedBytes()) /
static_cast<double>(TotalBytes());
}
private:
explicit RequestResponseCache(const uint64_t cache_size);
// Update LRU ordering on lookup
void UpdateLRU(std::unordered_map<uint64_t, CacheEntry>::iterator&);
// Build CacheEntry from InferenceResponse
Status BuildCacheEntry(
const InferenceResponse& response, CacheEntry* const entry);
// Build InferenceResponse from CacheEntry
Status BuildInferenceResponse(
const CacheEntry& entry, InferenceResponse* const response);
// Helper function to hash data buffers used by "input"
Status HashInputBuffers(const InferenceRequest::Input* input, size_t* seed);
// Helper function to hash each input in "request"
Status HashInputs(const InferenceRequest& request, size_t* seed);
// Helper function to hash request and store it in "key"
Status Hash(const InferenceRequest& request, uint64_t* key);
// Cache buffer
void* buffer_;
// Managed buffer
boost::interprocess::managed_external_buffer managed_buffer_;
// key -> CacheEntry containing values and list iterator for LRU management
std::unordered_map<uint64_t, CacheEntry> cache_;
// List of keys sorted from most to least recently used
std::list<uint64_t> lru_;
// Cache metrics
size_t num_evictions_ = 0;
size_t num_lookups_ = 0;
size_t num_hits_ = 0;
size_t num_misses_ = 0;
uint64_t total_lookup_latency_ns_ = 0;
uint64_t total_insertion_latency_ns_ = 0;
// Mutex for buffer synchronization
std::recursive_mutex buffer_mtx_;
// Mutex for cache synchronization
std::recursive_mutex cache_mtx_;
};
}} // namespace triton::core
// Copyright (c) 2018-2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include "infer_request.h"
#include "status.h"
namespace triton { namespace core {
// Scheduler interface.
class Scheduler {
public:
virtual ~Scheduler() {}
// The prototype for the initialization function that will be called
// by the "standard" schedulers created based on a model's
// scheduling_choice settings. The init function is called once by
// the runner that will later execute requests for 'runner_idx'. A
// non-OK error status indicates an initialization error that
  // prevents the scheduler from using the runner.
using StandardInitFunc = std::function<Status(uint32_t runner_idx)>;
// The prototype for the warmup function that will be called by the
// "standard" schedulers created based on a model's
// scheduling_choice settings. The warmup function is called once by
// the runner that will later execute requests for 'runner_idx'. A
  // non-OK error status indicates an error that prevents the scheduler
// from sending warmup requests to the runner.
using StandardWarmupFunc = std::function<Status(uint32_t runner_idx)>;
// The prototype for the run function that will be called by the
// "standard" schedulers created based on a model's
// scheduling_choice settings. The run function must accept a
// 'runner_idx' indicating which runner should execute the
// 'requests'. Ownership of the 'requests' is transferred to the
// runner which is responsible for generating responses and
// releasing the requests.
using StandardRunFunc = std::function<void(
uint32_t runner_idx,
std::vector<std::unique_ptr<InferenceRequest>>&& requests)>;
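  // Illustrative sketch (hypothetical runner wiring): a scheduler
  // implementation typically receives a StandardRunFunc and forwards batched
  // requests to it, e.g.
  //
  //   StandardRunFunc run =
  //       [](uint32_t runner_idx,
  //          std::vector<std::unique_ptr<InferenceRequest>>&& requests) {
  //         // execute 'requests' on runner 'runner_idx' and release them
  //       };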
// Enqueue a request with the scheduler. If Status::Success is returned
// then the backend has taken ownership of the request object and so
// 'request' will be nullptr. If non-success is returned then the
// caller still retains ownership of 'request'.
virtual Status Enqueue(std::unique_ptr<InferenceRequest>& request) = 0;
// Return the number of in-flight inferences tracked by the scheduler.
virtual size_t InflightInferenceCount() = 0;
  // Instruct the scheduler to stop processing new requests; requests that are
  // already considered in-flight are still processed to completion.
virtual void Stop() = 0;
};
}} // namespace triton::core