// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//  * Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
//  * Redistributions in binary form must reproduce the above copyright
//    notice, this list of conditions and the following disclaimer in the
//    documentation and/or other materials provided with the distribution.
//  * Neither the name of NVIDIA CORPORATION nor the names of its
//    contributors may be used to endorse or promote products derived
//    from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

#include "model.h"

#include <cstdint>
#include <memory>
#include <string>
#include <utility>

#include "constants.h"
#include "filesystem.h"
#include "infer_request.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"

namespace triton { namespace core {

Status
Model::GetInput(
    const std::string& name, const inference::ModelInput** input) const
{
  const auto itr = input_map_.find(name);
  if (itr == input_map_.end()) {
    return Status(
        Status::Code::INVALID_ARG, "unexpected inference input '" + name +
                                       "' for model '" + Name() + "'");
  }

  *input = &itr->second;
  return Status::Success;
}

Status
Model::GetOutput(
    const std::string& name, const inference::ModelOutput** output) const
{
  const auto itr = output_map_.find(name);
  if (itr == output_map_.end()) {
    return Status(
        Status::Code::INVALID_ARG, "unexpected inference output '" + name +
                                       "' for model '" + Name() + "'");
  }

  *output = &itr->second;
  return Status::Success;
}

Status
Model::SetModelConfig(const inference::ModelConfig& config)
{
  config_ = config;
  set_model_config_ = true;
  return Status::Success;
}

Status
Model::SetScheduler(std::unique_ptr<Scheduler> scheduler)
{
  if (scheduler_ != nullptr) {
    return Status(
        Status::Code::INTERNAL, "Attempt to change scheduler not allowed");
  }

  scheduler_ = std::move(scheduler);
  return Status::Success;
}

Status
Model::Init(const bool is_config_provided)
{
  if (!set_model_config_ && !is_config_provided) {
    return Status(
        Status::Code::NOT_FOUND,
        "model configuration is not provided for model '" + Name() + "'");
  }

  RETURN_IF_ERROR(ValidateModelConfig(config_, min_compute_capability_));
  RETURN_IF_ERROR(ValidateModelIOConfig(config_));

  // Initialize the input map and count the inputs that must be present
  // in every inference request.
  for (const auto& io : config_.input()) {
    input_map_.insert(std::make_pair(io.name(), io));
    if (!io.optional()) {
      ++required_input_count_;
    }
  }

  // Initialize the output map, and the label provider for each output
  // that specifies a label file.
  label_provider_ = std::make_shared<LabelProvider>();
  for (const auto& io : config_.output()) {
    output_map_.insert(std::make_pair(io.name(), io));
    if (!io.label_filename().empty()) {
      const auto
          label_path = JoinPath({model_dir_, io.label_filename()});
      RETURN_IF_ERROR(label_provider_->AddLabels(io.name(), label_path));
    }
  }

  if (config_.has_dynamic_batching()) {
    default_priority_level_ =
        config_.dynamic_batching().default_priority_level();
    max_priority_level_ = config_.dynamic_batching().priority_levels();
  } else if (config_.has_ensemble_scheduling()) {
    // For ensemble, allow any priority level to pass through
    default_priority_level_ = 0;
    max_priority_level_ = UINT32_MAX;
  } else {
    default_priority_level_ = 0;
    max_priority_level_ = 0;
  }

  return Status::Success;
}

}}  // namespace triton::core