Commit b30f3cdb authored by xiabo

Add the downloaded code

parent e38ee081
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "label_provider.h"
#include <iostream>
#include <iterator>
#include <sstream>
#include "filesystem.h"
namespace triton { namespace core {
const std::string&
LabelProvider::GetLabel(const std::string& name, size_t index) const
{
static const std::string not_found;
auto itr = label_map_.find(name);
if (itr == label_map_.end()) {
return not_found;
}
if (itr->second.size() <= index) {
return not_found;
}
return itr->second[index];
}
Status
LabelProvider::AddLabels(const std::string& name, const std::string& filepath)
{
std::string label_file_contents;
RETURN_IF_ERROR(ReadTextFile(filepath, &label_file_contents));
auto p = label_map_.insert(std::make_pair(name, std::vector<std::string>()));
if (!p.second) {
return Status(
Status::Code::INTERNAL, "multiple label files for '" + name + "'");
}
auto itr = p.first;
std::istringstream label_file_stream(label_file_contents);
std::string line;
while (std::getline(label_file_stream, line)) {
itr->second.push_back(line);
}
return Status::Success;
}
const std::vector<std::string>&
LabelProvider::GetLabels(const std::string& name)
{
static const std::vector<std::string> not_found;
auto itr = label_map_.find(name);
if (itr == label_map_.end()) {
return not_found;
}
return itr->second;
}
Status
LabelProvider::AddLabels(
const std::string& name, const std::vector<std::string>& labels)
{
label_map_.emplace(name, labels);
return Status::Success;
}
}} // namespace triton::core
// Copyright (c) 2018-2019, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include <unordered_map>
#include <vector>
#include "constants.h"
#include "status.h"
namespace triton { namespace core {
// Provides classification labels.
class LabelProvider {
public:
LabelProvider() = default;
// Return the label associated with 'name' for a given
// 'index'. Return an empty string if no label is available.
const std::string& GetLabel(const std::string& name, size_t index) const;
// Associate with 'name' a set of labels initialized from a given
// 'filepath'. Within the file each label is specified on its own
// line. The first label (line 0) is the index-0 label, the second
// label (line 1) is the index-1 label, etc.
Status AddLabels(const std::string& name, const std::string& filepath);
// Return the labels associated with 'name'. Return an empty vector if no
// labels are available.
const std::vector<std::string>& GetLabels(const std::string& name);
// Associate with 'name' a set of 'labels'
Status AddLabels(
const std::string& name, const std::vector<std::string>& labels);
private:
DISALLOW_COPY_AND_ASSIGN(LabelProvider);
std::unordered_map<std::string, std::vector<std::string>> label_map_;
};
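// Example (illustrative; the output name and file path below are
// hypothetical):
//
//   LabelProvider provider;
//   provider.AddLabels("output0", "/models/resnet50/labels.txt");
//   const std::string& top1 = provider.GetLabel("output0", 0);  // line 0
//   const std::vector<std::string>& all = provider.GetLabels("output0");
//
// GetLabel() and GetLabels() return an empty string / vector for unknown
// names or out-of-range indices, and the file-based AddLabels() returns an
// error if 'name' already has labels associated with it.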
}} // namespace triton::core
# Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
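# Linker version script: only the TRITONSERVER_*, TRITONBACKEND_*, and
# TRITONREPOAGENT_* C-API symbols are exported from the shared library;
# every other symbol is kept local.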
{
global:
TRITONSERVER_*;
TRITONBACKEND_*;
TRITONREPOAGENT_*;
local: *;
};
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "memory.h"
#include "pinned_memory_manager.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#include "cuda_memory_manager.h"
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {
//
// MemoryReference
//
MemoryReference::MemoryReference() : Memory() {}
const char*
MemoryReference::BufferAt(
size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id) const
{
if (idx >= buffer_.size()) {
*byte_size = 0;
*memory_type = TRITONSERVER_MEMORY_CPU;
*memory_type_id = 0;
return nullptr;
}
*memory_type = buffer_[idx].buffer_attributes_.MemoryType();
*memory_type_id = buffer_[idx].buffer_attributes_.MemoryTypeId();
*byte_size = buffer_[idx].buffer_attributes_.ByteSize();
return buffer_[idx].buffer_;
}
const char*
MemoryReference::BufferAt(size_t idx, BufferAttributes** buffer_attributes)
{
if (idx >= buffer_.size()) {
*buffer_attributes = nullptr;
return nullptr;
}
*buffer_attributes = &(buffer_[idx].buffer_attributes_);
return buffer_[idx].buffer_;
}
size_t
MemoryReference::AddBuffer(
const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
total_byte_size_ += byte_size;
buffer_count_++;
buffer_.emplace_back(buffer, byte_size, memory_type, memory_type_id);
return buffer_.size() - 1;
}
size_t
MemoryReference::AddBuffer(
const char* buffer, BufferAttributes* buffer_attributes)
{
total_byte_size_ += buffer_attributes->ByteSize();
buffer_count_++;
buffer_.emplace_back(buffer, buffer_attributes);
return buffer_.size() - 1;
}
size_t
MemoryReference::AddBufferFront(
const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
total_byte_size_ += byte_size;
buffer_count_++;
buffer_.emplace(
buffer_.begin(), buffer, byte_size, memory_type, memory_type_id);
return buffer_.size() - 1;
}
//
// MutableMemory
//
MutableMemory::MutableMemory(
char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
: Memory(), buffer_(buffer),
buffer_attributes_(
BufferAttributes(byte_size, memory_type, memory_type_id, nullptr))
{
total_byte_size_ = byte_size;
buffer_count_ = (byte_size == 0) ? 0 : 1;
}
const char*
MutableMemory::BufferAt(
size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id) const
{
if (idx != 0) {
*byte_size = 0;
*memory_type = TRITONSERVER_MEMORY_CPU;
*memory_type_id = 0;
return nullptr;
}
*byte_size = total_byte_size_;
*memory_type = buffer_attributes_.MemoryType();
*memory_type_id = buffer_attributes_.MemoryTypeId();
return buffer_;
}
const char*
MutableMemory::BufferAt(size_t idx, BufferAttributes** buffer_attributes)
{
if (idx != 0) {
*buffer_attributes = nullptr;
return nullptr;
}
*buffer_attributes = &buffer_attributes_;
return buffer_;
}
char*
MutableMemory::MutableBuffer(
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
{
if (memory_type != nullptr) {
*memory_type = buffer_attributes_.MemoryType();
}
if (memory_type_id != nullptr) {
*memory_type_id = buffer_attributes_.MemoryTypeId();
}
return buffer_;
}
//
// AllocatedMemory
//
AllocatedMemory::AllocatedMemory(
size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
: MutableMemory(nullptr, byte_size, memory_type, memory_type_id)
{
if (total_byte_size_ != 0) {
// Allocate memory with the following fallback policy:
// CUDA memory -> pinned system memory -> non-pinned system memory
switch (buffer_attributes_.MemoryType()) {
#ifdef TRITON_ENABLE_GPU
case TRITONSERVER_MEMORY_GPU: {
auto status = CudaMemoryManager::Alloc(
(void**)&buffer_, total_byte_size_,
buffer_attributes_.MemoryTypeId());
if (!status.IsOk()) {
static bool warning_logged = false;
if (!warning_logged) {
LOG_WARNING << status.Message()
<< ", falling back to pinned system memory";
warning_logged = true;
}
goto pinned_memory_allocation;
}
break;
}
pinned_memory_allocation:
#endif // TRITON_ENABLE_GPU
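// Fallback path: reached directly for CPU / pinned-memory requests (and for
// all requests when TRITON_ENABLE_GPU is off), or via the
// 'pinned_memory_allocation' label above when CUDA allocation fails.
// PinnedMemoryManager::Alloc reports the memory type it actually provided
// through 'memory_type', which is then recorded in buffer_attributes_.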
default: {
TRITONSERVER_MemoryType memory_type = buffer_attributes_.MemoryType();
auto status = PinnedMemoryManager::Alloc(
(void**)&buffer_, total_byte_size_, &memory_type, true);
buffer_attributes_.SetMemoryType(memory_type);
if (!status.IsOk()) {
LOG_ERROR << status.Message();
buffer_ = nullptr;
}
break;
}
}
}
total_byte_size_ = (buffer_ == nullptr) ? 0 : total_byte_size_;
}
AllocatedMemory::~AllocatedMemory()
{
if (buffer_ != nullptr) {
switch (buffer_attributes_.MemoryType()) {
case TRITONSERVER_MEMORY_GPU: {
#ifdef TRITON_ENABLE_GPU
auto status =
CudaMemoryManager::Free(buffer_, buffer_attributes_.MemoryTypeId());
if (!status.IsOk()) {
LOG_ERROR << status.Message();
}
#endif // TRITON_ENABLE_GPU
break;
}
default: {
auto status = PinnedMemoryManager::Free(buffer_);
if (!status.IsOk()) {
LOG_ERROR << status.Message();
buffer_ = nullptr;
}
break;
}
}
buffer_ = nullptr;
}
}
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <vector>
#include "buffer_attributes.h"
#include "constants.h"
#include "status.h"
namespace triton { namespace core {
//
// Memory used to access data in inference requests
//
class Memory {
public:
// Get the 'idx'-th data block in the buffer. An index is used instead of
// internal iteration state so that one buffer can be shared across
// multiple providers.
// 'idx' zero-based index. Valid indices are contiguous.
// 'byte_size' returns the byte size of the chunk of bytes.
// 'memory_type' returns the memory type of the chunk of bytes.
// 'memory_type_id' returns the memory type id of the chunk of bytes.
// Return the pointer to the data block. Returns nullptr if 'idx' is
// out of range.
virtual const char* BufferAt(
size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id) const = 0;
// Similar to the above BufferAt but with BufferAttributes.
virtual const char* BufferAt(
size_t idx, BufferAttributes** buffer_attributes) = 0;
// Get the number of contiguous buffers composing the memory.
size_t BufferCount() const { return buffer_count_; }
// Return the total byte size of the data buffer
size_t TotalByteSize() const { return total_byte_size_; }
protected:
Memory() : total_byte_size_(0), buffer_count_(0) {}
size_t total_byte_size_;
size_t buffer_count_;
};
//
// MemoryReference
//
class MemoryReference : public Memory {
public:
// Create a read-only data buffer as a reference to another data buffer.
MemoryReference();
//\see Memory::BufferAt()
const char* BufferAt(
size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id) const override;
const char* BufferAt(
size_t idx, BufferAttributes** buffer_attributes) override;
// Add a 'buffer' with 'byte_size' as part of this data buffer
// Return the index of the buffer
size_t AddBuffer(
const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
size_t AddBuffer(const char* buffer, BufferAttributes* buffer_attributes);
// Add a 'buffer' with 'byte_size' as part of this data buffer in the front
// Return the index of the buffer
size_t AddBufferFront(
const char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
private:
struct Block {
Block(
const char* buffer, size_t byte_size,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
: buffer_(buffer), buffer_attributes_(BufferAttributes(
byte_size, memory_type, memory_type_id, nullptr))
{
}
Block(const char* buffer, BufferAttributes* buffer_attributes)
: buffer_(buffer), buffer_attributes_(*buffer_attributes)
{
}
const char* buffer_;
BufferAttributes buffer_attributes_;
};
std::vector<Block> buffer_;
};
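// Example (illustrative; 'data_ptr' and 'data_size' are hypothetical):
//
//   MemoryReference input;
//   input.AddBuffer(data_ptr, data_size, TRITONSERVER_MEMORY_CPU, 0);
//   size_t byte_size;
//   TRITONSERVER_MemoryType memory_type;
//   int64_t memory_type_id;
//   const char* base =
//       input.BufferAt(0, &byte_size, &memory_type, &memory_type_id);
//
// The referenced data is neither copied nor owned; the caller must keep it
// alive for the lifetime of the MemoryReference.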
//
// MutableMemory
//
class MutableMemory : public Memory {
public:
// Create a mutable data buffer referencing another data buffer.
MutableMemory(
char* buffer, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
virtual ~MutableMemory() {}
//\see Memory::BufferAt()
const char* BufferAt(
size_t idx, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id) const override;
//\see Memory::BufferAt()
const char* BufferAt(
size_t idx, BufferAttributes** buffer_attributes) override;
// Return a pointer to the base address of the mutable buffer. If
// non-null 'memory_type' returns the memory type of the chunk of
// bytes. If non-null 'memory_type_id' returns the memory type id of
// the chunk of bytes.
char* MutableBuffer(
TRITONSERVER_MemoryType* memory_type = nullptr,
int64_t* memory_type_id = nullptr);
DISALLOW_COPY_AND_ASSIGN(MutableMemory);
protected:
MutableMemory() : Memory() {}
char* buffer_;
BufferAttributes buffer_attributes_;
};
//
// AllocatedMemory
//
class AllocatedMemory : public MutableMemory {
public:
// Create a contiguous data buffer with 'byte_size', 'memory_type' and
// 'memory_type_id'. Note that the buffer may be created on a different
// memory type and memory type id if the original request cannot be
// satisfied, so the caller should always check the actual memory type
// and memory type id before use.
AllocatedMemory(
size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
~AllocatedMemory() override;
};
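// Example (illustrative; 'byte_size' is a hypothetical size): because of the
// allocation fallback noted above, the actual placement should be checked
// after construction:
//
//   AllocatedMemory output(byte_size, TRITONSERVER_MEMORY_GPU, 0);
//   TRITONSERVER_MemoryType actual_type;
//   int64_t actual_id;
//   char* base = output.MutableBuffer(&actual_type, &actual_id);
//   // 'actual_type' may be CPU_PINNED or CPU if GPU allocation failed.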
}} // namespace triton::core
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_METRICS
#include "metric_family.h"
#include "metrics.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
//
// Implementation for TRITONSERVER_MetricFamily.
//
MetricFamily::MetricFamily(
TRITONSERVER_MetricKind kind, const char* name, const char* description)
{
auto registry = Metrics::GetRegistry();
switch (kind) {
case TRITONSERVER_METRIC_KIND_COUNTER:
family_ = reinterpret_cast<void*>(&prometheus::BuildCounter()
.Name(name)
.Help(description)
.Register(*registry));
break;
case TRITONSERVER_METRIC_KIND_GAUGE:
family_ = reinterpret_cast<void*>(&prometheus::BuildGauge()
.Name(name)
.Help(description)
.Register(*registry));
break;
default:
throw std::invalid_argument(
"Unsupported kind passed to MetricFamily constructor.");
}
kind_ = kind;
}
void*
MetricFamily::Add(std::map<std::string, std::string> label_map, Metric* metric)
{
void* prom_metric = nullptr;
switch (kind_) {
case TRITONSERVER_METRIC_KIND_COUNTER: {
auto counter_family_ptr =
reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
auto counter_ptr = &counter_family_ptr->Add(label_map);
prom_metric = reinterpret_cast<void*>(counter_ptr);
break;
}
case TRITONSERVER_METRIC_KIND_GAUGE: {
auto gauge_family_ptr =
reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
auto gauge_ptr = &gauge_family_ptr->Add(label_map);
prom_metric = reinterpret_cast<void*>(gauge_ptr);
break;
}
default:
throw std::invalid_argument(
"Unsupported family kind passed to Metric constructor.");
}
std::lock_guard<std::mutex> lk(metric_mtx_);
++prom_metric_ref_cnt_[prom_metric];
child_metrics_.insert(metric);
return prom_metric;
}
void
MetricFamily::Remove(void* prom_metric, Metric* metric)
{
{
// Remove reference to dependent Metric object
std::lock_guard<std::mutex> lk(metric_mtx_);
child_metrics_.erase(metric);
}
if (prom_metric == nullptr) {
return;
}
{
std::lock_guard<std::mutex> lk(metric_mtx_);
const auto it = prom_metric_ref_cnt_.find(prom_metric);
if (it != prom_metric_ref_cnt_.end()) {
--it->second;
if (it->second == 0) {
prom_metric_ref_cnt_.erase(it);
} else {
// Done as it is not the last reference
return;
}
}
}
switch (kind_) {
case TRITONSERVER_METRIC_KIND_COUNTER: {
auto counter_family_ptr =
reinterpret_cast<prometheus::Family<prometheus::Counter>*>(family_);
auto counter_ptr = reinterpret_cast<prometheus::Counter*>(prom_metric);
counter_family_ptr->Remove(counter_ptr);
break;
}
case TRITONSERVER_METRIC_KIND_GAUGE: {
auto gauge_family_ptr =
reinterpret_cast<prometheus::Family<prometheus::Gauge>*>(family_);
auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(prom_metric);
gauge_family_ptr->Remove(gauge_ptr);
break;
}
default:
// Invalid kind should be caught in constructor
LOG_ERROR << "Unsupported kind in Metric destructor.";
break;
}
}
void
MetricFamily::InvalidateReferences()
{
std::lock_guard<std::mutex> lk(metric_mtx_);
for (auto& metric : child_metrics_) {
if (metric != nullptr) {
metric->Invalidate();
}
}
child_metrics_.clear();
}
MetricFamily::~MetricFamily()
{
if (NumMetrics() > 0) {
LOG_WARNING << "MetricFamily was deleted before its child Metrics, this "
"should not happen. Make sure to delete all child Metrics "
"before deleting their MetricFamily.";
}
InvalidateReferences();
// DLIS-4072: Support for removing metric families from registry
}
//
// Implementation for TRITONSERVER_Metric.
//
Metric::Metric(
TRITONSERVER_MetricFamily* family,
std::vector<const InferenceParameter*> labels)
{
family_ = reinterpret_cast<MetricFamily*>(family);
kind_ = family_->Kind();
// Create map of labels from InferenceParameters
std::map<std::string, std::string> label_map;
for (const auto& param : labels) {
if (param->Type() != TRITONSERVER_PARAMETER_STRING) {
throw std::invalid_argument(
"Parameter [" + param->Name() +
"] must have a type of TRITONSERVER_PARAMETER_STRING to be "
"added as a label.");
}
label_map[param->Name()] =
std::string(reinterpret_cast<const char*>(param->ValuePointer()));
}
metric_ = family_->Add(label_map, this);
}
Metric::~Metric()
{
if (family_ != nullptr) {
family_->Remove(metric_, this);
} else {
LOG_WARNING << "Corresponding MetricFamily was deleted before this Metric, "
"this should not happen. Make sure to delete a Metric "
"before deleting its MetricFamily.";
}
// Catch lifetime management / invalid reference issues
Invalidate();
}
void
Metric::Invalidate()
{
family_ = nullptr;
metric_ = nullptr;
}
TRITONSERVER_Error*
Metric::Value(double* value)
{
if (metric_ == nullptr) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"Could not get metric value. Metric has been invalidated.");
}
switch (kind_) {
case TRITONSERVER_METRIC_KIND_COUNTER: {
auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
LOG_VERBOSE(1) << "SETTING COUNTER METRIC FROM: " << *value << " to "
<< counter_ptr->Value();
*value = counter_ptr->Value();
break;
}
case TRITONSERVER_METRIC_KIND_GAUGE: {
auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
LOG_VERBOSE(1) << "SETTING GAUGE METRIC FROM: " << *value << " to "
<< gauge_ptr->Value();
*value = gauge_ptr->Value();
break;
}
default:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"Unsupported TRITONSERVER_MetricKind");
}
return nullptr; // Success
}
TRITONSERVER_Error*
Metric::Increment(double value)
{
if (metric_ == nullptr) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"Could not increment metric value. Metric has been invalidated.");
}
switch (kind_) {
case TRITONSERVER_METRIC_KIND_COUNTER: {
if (value < 0.0) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"TRITONSERVER_METRIC_KIND_COUNTER can only be incremented "
"monotonically by non-negative values.");
}
auto counter_ptr = reinterpret_cast<prometheus::Counter*>(metric_);
counter_ptr->Increment(value);
break;
}
case TRITONSERVER_METRIC_KIND_GAUGE: {
auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
// Gauge::Increment works for both positive and negative values as of
// prometheus-cpp v1.0 but for now on v0.7 we defer call to
// Increment/Decrement based on the sign of value
// https://github.com/jupp0r/prometheus-cpp/blob/master/core/src/gauge.cc
if (value < 0.0) {
gauge_ptr->Decrement(-1.0 * value);
} else {
gauge_ptr->Increment(value);
}
break;
}
default:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"Unsupported TRITONSERVER_MetricKind");
}
return nullptr; // Success
}
TRITONSERVER_Error*
Metric::Set(double value)
{
if (metric_ == nullptr) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"Could not set metric value. Metric has been invalidated.");
}
switch (kind_) {
case TRITONSERVER_METRIC_KIND_COUNTER: {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"TRITONSERVER_METRIC_KIND_COUNTER does not support Set");
}
case TRITONSERVER_METRIC_KIND_GAUGE: {
auto gauge_ptr = reinterpret_cast<prometheus::Gauge*>(metric_);
gauge_ptr->Set(value);
break;
}
default:
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_UNSUPPORTED,
"Unsupported TRITONSERVER_MetricKind");
}
return nullptr; // Success
}
}} // namespace triton::core
#endif // TRITON_ENABLE_METRICS
// Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_METRICS
#include <mutex>
#include <set>
#include <unordered_map>
#include "infer_parameter.h"
#include "prometheus/registry.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
//
// Implementation for TRITONSERVER_MetricFamily.
//
class Metric;
class MetricFamily {
public:
MetricFamily(
TRITONSERVER_MetricKind kind, const char* name, const char* description);
~MetricFamily();
void* Family() const { return family_; }
TRITONSERVER_MetricKind Kind() const { return kind_; }
void* Add(std::map<std::string, std::string> label_map, Metric* metric);
void Remove(void* prom_metric, Metric* metric);
int NumMetrics()
{
std::lock_guard<std::mutex> lk(metric_mtx_);
return child_metrics_.size();
}
private:
// If a MetricFamily is deleted before its dependent Metric, we want to
// invalidate the reference so we don't access invalid memory.
void InvalidateReferences();
void* family_;
TRITONSERVER_MetricKind kind_;
// Synchronize access of related metric objects
std::mutex metric_mtx_;
// Prometheus returns the existing metric pointer if a metric with the same
// set of labels is requested; as a result, different Metric objects may
// refer to the same prometheus metric. So we must track the reference count
// of the metric and ask prometheus to remove it only when all references
// are released.
std::unordered_map<void*, size_t> prom_metric_ref_cnt_;
// Maintain references to metrics created from this metric family to
// invalidate their references if a family is deleted before its metric
std::set<Metric*> child_metrics_;
};
//
// Implementation for TRITONSERVER_Metric.
//
class Metric {
public:
Metric(
TRITONSERVER_MetricFamily* family,
std::vector<const InferenceParameter*> labels);
~Metric();
MetricFamily* Family() const { return family_; }
TRITONSERVER_MetricKind Kind() const { return kind_; }
TRITONSERVER_Error* Value(double* value);
TRITONSERVER_Error* Increment(double value);
TRITONSERVER_Error* Set(double value);
// If a MetricFamily is deleted before its dependent Metric, we want to
// invalidate the references so we don't access invalid memory.
void Invalidate();
private:
void* metric_;
MetricFamily* family_;
TRITONSERVER_MetricKind kind_;
};
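// Lifetime note: a Metric should be deleted before its MetricFamily. Both
// destructors tolerate the wrong order (a warning is logged and
// MetricFamily::InvalidateReferences() clears the children's pointers so no
// freed memory is accessed), but the underlying prometheus series is only
// removed from its family once every Metric referencing it has been
// released (see prom_metric_ref_cnt_).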
}} // namespace triton::core
#endif // TRITON_ENABLE_METRICS
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "metric_model_reporter.h"
#ifdef TRITON_ENABLE_METRICS
#include "constants.h"
#include "metrics.h"
namespace triton { namespace core {
Status
MetricModelReporter::Create(
const std::string& model_name, const int64_t model_version,
const int device, const triton::common::MetricTagsMap& model_tags,
std::shared_ptr<MetricModelReporter>* metric_model_reporter)
{
static std::mutex mtx;
static std::unordered_map<size_t, std::weak_ptr<MetricModelReporter>>
reporter_map;
std::map<std::string, std::string> labels;
GetMetricLabels(&labels, model_name, model_version, device, model_tags);
auto hash_labels = Metrics::HashLabels(labels);
std::lock_guard<std::mutex> lock(mtx);
const auto& itr = reporter_map.find(hash_labels);
if (itr != reporter_map.end()) {
// Found in map. If the weak_ptr is still valid that means that
// there are other models using the reporter and we just reuse that
// same reporter. If the weak_ptr is not valid then we need to remove
// the weak_ptr from the map and create the reporter again.
*metric_model_reporter = itr->second.lock();
if (*metric_model_reporter != nullptr) {
return Status::Success;
}
reporter_map.erase(itr);
}
metric_model_reporter->reset(
new MetricModelReporter(model_name, model_version, device, model_tags));
reporter_map.insert({hash_labels, *metric_model_reporter});
return Status::Success;
}
MetricModelReporter::MetricModelReporter(
const std::string& model_name, const int64_t model_version,
const int device, const triton::common::MetricTagsMap& model_tags)
{
std::map<std::string, std::string> labels;
GetMetricLabels(&labels, model_name, model_version, device, model_tags);
metric_inf_success_ =
CreateCounterMetric(Metrics::FamilyInferenceSuccess(), labels);
metric_inf_failure_ =
CreateCounterMetric(Metrics::FamilyInferenceFailure(), labels);
metric_inf_count_ =
CreateCounterMetric(Metrics::FamilyInferenceCount(), labels);
metric_inf_exec_count_ =
CreateCounterMetric(Metrics::FamilyInferenceExecutionCount(), labels);
metric_inf_request_duration_us_ =
CreateCounterMetric(Metrics::FamilyInferenceRequestDuration(), labels);
metric_inf_queue_duration_us_ =
CreateCounterMetric(Metrics::FamilyInferenceQueueDuration(), labels);
metric_inf_compute_input_duration_us_ = CreateCounterMetric(
Metrics::FamilyInferenceComputeInputDuration(), labels);
metric_inf_compute_infer_duration_us_ = CreateCounterMetric(
Metrics::FamilyInferenceComputeInferDuration(), labels);
metric_inf_compute_output_duration_us_ = CreateCounterMetric(
Metrics::FamilyInferenceComputeOutputDuration(), labels);
metric_cache_hit_count_ =
CreateCounterMetric(Metrics::FamilyCacheHitCount(), labels);
metric_cache_hit_lookup_duration_us_ =
CreateCounterMetric(Metrics::FamilyCacheHitLookupDuration(), labels);
metric_cache_miss_count_ =
CreateCounterMetric(Metrics::FamilyCacheMissCount(), labels);
metric_cache_miss_lookup_duration_us_ =
CreateCounterMetric(Metrics::FamilyCacheMissLookupDuration(), labels);
metric_cache_miss_insertion_duration_us_ =
CreateCounterMetric(Metrics::FamilyCacheMissInsertionDuration(), labels);
}
MetricModelReporter::~MetricModelReporter()
{
Metrics::FamilyInferenceSuccess().Remove(metric_inf_success_);
Metrics::FamilyInferenceFailure().Remove(metric_inf_failure_);
Metrics::FamilyInferenceCount().Remove(metric_inf_count_);
Metrics::FamilyInferenceExecutionCount().Remove(metric_inf_exec_count_);
Metrics::FamilyInferenceRequestDuration().Remove(
metric_inf_request_duration_us_);
Metrics::FamilyInferenceQueueDuration().Remove(metric_inf_queue_duration_us_);
Metrics::FamilyInferenceComputeInputDuration().Remove(
metric_inf_compute_input_duration_us_);
Metrics::FamilyInferenceComputeInferDuration().Remove(
metric_inf_compute_infer_duration_us_);
Metrics::FamilyInferenceComputeOutputDuration().Remove(
metric_inf_compute_output_duration_us_);
Metrics::FamilyCacheHitCount().Remove(metric_cache_hit_count_);
Metrics::FamilyCacheHitLookupDuration().Remove(
metric_cache_hit_lookup_duration_us_);
Metrics::FamilyCacheMissCount().Remove(metric_cache_miss_count_);
Metrics::FamilyCacheMissLookupDuration().Remove(
metric_cache_miss_lookup_duration_us_);
Metrics::FamilyCacheMissInsertionDuration().Remove(
metric_cache_miss_insertion_duration_us_);
}
void
MetricModelReporter::GetMetricLabels(
std::map<std::string, std::string>* labels, const std::string& model_name,
const int64_t model_version, const int device,
const triton::common::MetricTagsMap& model_tags)
{
labels->insert(std::map<std::string, std::string>::value_type(
std::string(kMetricsLabelModelName), model_name));
labels->insert(std::map<std::string, std::string>::value_type(
std::string(kMetricsLabelModelVersion), std::to_string(model_version)));
for (const auto& tag : model_tags) {
labels->insert(std::map<std::string, std::string>::value_type(
"_" + tag.first, tag.second));
}
// 'device' can be < 0 to indicate that the GPU is not known. In
// that case use a metric that doesn't have the gpu_uuid label.
if (device >= 0) {
std::string uuid;
if (Metrics::UUIDForCudaDevice(device, &uuid)) {
labels->insert(std::map<std::string, std::string>::value_type(
std::string(kMetricsLabelGpuUuid), uuid));
}
}
}
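// The label set built above keys every per-model counter:
// kMetricsLabelModelName -> model name, kMetricsLabelModelVersion -> version,
// one underscore-prefixed entry per model tag, and kMetricsLabelGpuUuid when
// a valid GPU device index is given and its UUID can be resolved.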
prometheus::Counter*
MetricModelReporter::CreateCounterMetric(
prometheus::Family<prometheus::Counter>& family,
const std::map<std::string, std::string>& labels)
{
return &family.Add(labels);
}
}} // namespace triton::core
#endif // TRITON_ENABLE_METRICS
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "status.h"
#include "triton/common/model_config.h"
#ifdef TRITON_ENABLE_METRICS
#include "prometheus/registry.h"
#endif // TRITON_ENABLE_METRICS
namespace triton { namespace core {
//
// Interface for a metric reporter for a given version of a model.
//
class MetricModelReporter {
public:
#ifdef TRITON_ENABLE_METRICS
static Status Create(
const std::string& model_name, const int64_t model_version,
const int device, const triton::common::MetricTagsMap& model_tags,
std::shared_ptr<MetricModelReporter>* metric_model_reporter);
~MetricModelReporter();
// Get a metric for the given model, version and GPU index.
prometheus::Counter& MetricInferenceSuccess() const
{
return *metric_inf_success_;
}
prometheus::Counter& MetricInferenceFailure() const
{
return *metric_inf_failure_;
}
prometheus::Counter& MetricInferenceCount() const
{
return *metric_inf_count_;
}
prometheus::Counter& MetricInferenceExecutionCount() const
{
return *metric_inf_exec_count_;
}
prometheus::Counter& MetricInferenceRequestDuration() const
{
return *metric_inf_request_duration_us_;
}
prometheus::Counter& MetricInferenceQueueDuration() const
{
return *metric_inf_queue_duration_us_;
}
prometheus::Counter& MetricInferenceComputeInputDuration() const
{
return *metric_inf_compute_input_duration_us_;
}
prometheus::Counter& MetricInferenceComputeInferDuration() const
{
return *metric_inf_compute_infer_duration_us_;
}
prometheus::Counter& MetricInferenceComputeOutputDuration() const
{
return *metric_inf_compute_output_duration_us_;
}
prometheus::Counter& MetricCacheHitCount() const
{
return *metric_cache_hit_count_;
}
prometheus::Counter& MetricCacheHitLookupDuration() const
{
return *metric_cache_hit_lookup_duration_us_;
}
prometheus::Counter& MetricCacheMissCount() const
{
return *metric_cache_miss_count_;
}
prometheus::Counter& MetricCacheMissLookupDuration() const
{
return *metric_cache_miss_lookup_duration_us_;
}
prometheus::Counter& MetricCacheMissInsertionDuration() const
{
return *metric_cache_miss_insertion_duration_us_;
}
private:
MetricModelReporter(
const std::string& model_name, const int64_t model_version,
const int device, const triton::common::MetricTagsMap& model_tags);
static void GetMetricLabels(
std::map<std::string, std::string>* labels, const std::string& model_name,
const int64_t model_version, const int device,
const triton::common::MetricTagsMap& model_tags);
prometheus::Counter* CreateCounterMetric(
prometheus::Family<prometheus::Counter>& family,
const std::map<std::string, std::string>& labels);
prometheus::Counter* metric_inf_success_;
prometheus::Counter* metric_inf_failure_;
prometheus::Counter* metric_inf_count_;
prometheus::Counter* metric_inf_exec_count_;
prometheus::Counter* metric_inf_request_duration_us_;
prometheus::Counter* metric_inf_queue_duration_us_;
prometheus::Counter* metric_inf_compute_input_duration_us_;
prometheus::Counter* metric_inf_compute_infer_duration_us_;
prometheus::Counter* metric_inf_compute_output_duration_us_;
prometheus::Counter* metric_cache_hit_count_;
prometheus::Counter* metric_cache_hit_lookup_duration_us_;
prometheus::Counter* metric_cache_miss_count_;
prometheus::Counter* metric_cache_miss_lookup_duration_us_;
prometheus::Counter* metric_cache_miss_insertion_duration_us_;
#endif // TRITON_ENABLE_METRICS
};
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#ifdef TRITON_ENABLE_METRICS
#include "metrics.h"
#include <thread>
#include "constants.h"
#include "prometheus/detail/utils.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_METRICS_GPU
#include <cuda_runtime_api.h>
#include <dcgm_agent.h>
#include <cstring>
#include <set>
#include <string>
#endif // TRITON_ENABLE_METRICS_GPU
namespace triton { namespace core {
Metrics::Metrics()
: registry_(std::make_shared<prometheus::Registry>()),
serializer_(new prometheus::TextSerializer()),
inf_success_family_(
prometheus::BuildCounter()
.Name("nv_inference_request_success")
.Help("Number of successful inference requests, all batch sizes")
.Register(*registry_)),
inf_failure_family_(
prometheus::BuildCounter()
.Name("nv_inference_request_failure")
.Help("Number of failed inference requests, all batch sizes")
.Register(*registry_)),
inf_count_family_(prometheus::BuildCounter()
.Name("nv_inference_count")
.Help("Number of inferences performed (does not "
"include cached requests)")
.Register(*registry_)),
inf_count_exec_family_(prometheus::BuildCounter()
.Name("nv_inference_exec_count")
.Help("Number of model executions performed "
"(does not include cached requests)")
.Register(*registry_)),
inf_request_duration_us_family_(
prometheus::BuildCounter()
.Name("nv_inference_request_duration_us")
.Help("Cumulative inference request duration in microseconds "
"(includes cached requests)")
.Register(*registry_)),
inf_queue_duration_us_family_(
prometheus::BuildCounter()
.Name("nv_inference_queue_duration_us")
.Help("Cumulative inference queuing duration in microseconds "
"(includes cached requests)")
.Register(*registry_)),
inf_compute_input_duration_us_family_(
prometheus::BuildCounter()
.Name("nv_inference_compute_input_duration_us")
.Help("Cumulative compute input duration in microseconds (does "
"not include cached requests)")
.Register(*registry_)),
inf_compute_infer_duration_us_family_(
prometheus::BuildCounter()
.Name("nv_inference_compute_infer_duration_us")
.Help("Cumulative compute inference duration in microseconds "
"(does not include cached requests)")
.Register(*registry_)),
inf_compute_output_duration_us_family_(
prometheus::BuildCounter()
.Name("nv_inference_compute_output_duration_us")
.Help("Cumulative inference compute output duration in "
"microseconds (does not include cached requests)")
.Register(*registry_)),
cache_num_entries_family_(
prometheus::BuildGauge()
.Name("nv_cache_num_entries")
.Help("Number of responses stored in response cache")
.Register(*registry_)),
cache_num_lookups_family_(
prometheus::BuildGauge()
.Name("nv_cache_num_lookups")
.Help("Number of cache lookups in response cache")
.Register(*registry_)),
cache_num_hits_family_(prometheus::BuildGauge()
.Name("nv_cache_num_hits")
.Help("Number of cache hits in response cache")
.Register(*registry_)),
cache_num_misses_family_(
prometheus::BuildGauge()
.Name("nv_cache_num_misses")
.Help("Number of cache misses in response cache")
.Register(*registry_)),
cache_num_evictions_family_(
prometheus::BuildGauge()
.Name("nv_cache_num_evictions")
.Help("Number of cache evictions in response cache")
.Register(*registry_)),
cache_lookup_duration_us_family_(
prometheus::BuildGauge()
.Name("nv_cache_lookup_duration")
.Help(
"Total cache lookup duration (hit and miss), in microseconds")
.Register(*registry_)),
cache_insertion_duration_us_family_(
prometheus::BuildGauge()
.Name("nv_cache_insertion_duration")
.Help("Total cache insertion duration, in microseconds")
.Register(*registry_)),
cache_util_family_(prometheus::BuildGauge()
.Name("nv_cache_util")
.Help("Cache utilization [0.0 - 1.0]")
.Register(*registry_)),
// Per-model cache metric families
cache_num_hits_model_family_(prometheus::BuildCounter()
.Name("nv_cache_num_hits_per_model")
.Help("Number of cache hits per model")
.Register(*registry_)),
cache_hit_lookup_duration_us_model_family_(
prometheus::BuildCounter()
.Name("nv_cache_hit_lookup_duration_per_model")
.Help(
"Total cache hit lookup duration per model, in microseconds")
.Register(*registry_)),
cache_num_misses_model_family_(
prometheus::BuildCounter()
.Name("nv_cache_num_misses_per_model")
.Help("Number of cache misses per model")
.Register(*registry_)),
cache_miss_lookup_duration_us_model_family_(
prometheus::BuildCounter()
.Name("nv_cache_miss_lookup_duration_per_model")
.Help(
"Total cache miss lookup duration per model, in microseconds")
.Register(*registry_)),
cache_miss_insertion_duration_us_model_family_(
prometheus::BuildCounter()
.Name("nv_cache_miss_insertion_duration_per_model")
.Help("Total cache miss insertion duration per model, in "
"microseconds")
.Register(*registry_)),
#ifdef TRITON_ENABLE_METRICS_GPU
gpu_utilization_family_(prometheus::BuildGauge()
.Name("nv_gpu_utilization")
.Help("GPU utilization rate [0.0 - 1.0)")
.Register(*registry_)),
gpu_memory_total_family_(prometheus::BuildGauge()
.Name("nv_gpu_memory_total_bytes")
.Help("GPU total memory, in bytes")
.Register(*registry_)),
gpu_memory_used_family_(prometheus::BuildGauge()
.Name("nv_gpu_memory_used_bytes")
.Help("GPU used memory, in bytes")
.Register(*registry_)),
gpu_power_usage_family_(prometheus::BuildGauge()
.Name("nv_gpu_power_usage")
.Help("GPU power usage in watts")
.Register(*registry_)),
gpu_power_limit_family_(prometheus::BuildGauge()
.Name("nv_gpu_power_limit")
.Help("GPU power management limit in watts")
.Register(*registry_)),
gpu_energy_consumption_family_(
prometheus::BuildCounter()
.Name("nv_energy_consumption")
.Help("GPU energy consumption in joules since the Triton Server "
"started")
.Register(*registry_)),
#endif // TRITON_ENABLE_METRICS_GPU
#ifdef TRITON_ENABLE_METRICS_CPU
cpu_utilization_family_(prometheus::BuildGauge()
.Name("nv_cpu_utilization")
.Help("CPU utilization rate [0.0 - 1.0]")
.Register(*registry_)),
cpu_memory_total_family_(prometheus::BuildGauge()
.Name("nv_cpu_memory_total_bytes")
.Help("CPU total memory (RAM), in bytes")
.Register(*registry_)),
cpu_memory_used_family_(prometheus::BuildGauge()
.Name("nv_cpu_memory_used_bytes")
.Help("CPU used memory (RAM), in bytes")
.Register(*registry_)),
#endif // TRITON_ENABLE_METRICS_CPU
metrics_enabled_(false), gpu_metrics_enabled_(false),
cpu_metrics_enabled_(false), cache_metrics_enabled_(false),
metrics_interval_ms_(2000)
{
}
static prometheus::detail::LabelHasher label_hasher_;
size_t
Metrics::HashLabels(const std::map<std::string, std::string>& labels)
{
return label_hasher_(labels);
}
Metrics::~Metrics()
{
// Signal the polling thread to exit and then wait for it...
if (poll_thread_ != nullptr) {
poll_thread_exit_.store(true);
poll_thread_->join();
#ifdef TRITON_ENABLE_METRICS_GPU
if (dcgm_metadata_.dcgm_initialized_) {
dcgmReturn_t derr;
// Group destroy will return an error if groupId invalid or dcgm not
// initialized or configured correctly
derr = dcgmGroupDestroy(
dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_);
if (derr != DCGM_ST_OK) {
LOG_WARNING << "Unable to destroy DCGM group: " << errorString(derr);
}
// Stop and shutdown DCGM
if (dcgm_metadata_.standalone_) {
derr = dcgmDisconnect(dcgm_metadata_.dcgm_handle_);
} else {
derr = dcgmStopEmbedded(dcgm_metadata_.dcgm_handle_);
}
if (derr != DCGM_ST_OK) {
LOG_WARNING << "Unable to stop DCGM: " << errorString(derr);
}
derr = dcgmShutdown();
if (derr != DCGM_ST_OK) {
LOG_WARNING << "Unable to shutdown DCGM: " << errorString(derr);
}
}
#endif // TRITON_ENABLE_METRICS_GPU
}
}
bool
Metrics::Enabled()
{
auto singleton = GetSingleton();
return singleton->metrics_enabled_;
}
void
Metrics::EnableMetrics()
{
auto singleton = GetSingleton();
singleton->metrics_enabled_ = true;
}
void
Metrics::EnableCacheMetrics(
std::shared_ptr<RequestResponseCache> response_cache)
{
auto singleton = GetSingleton();
// Ensure thread-safe enabling of Cache Metrics
std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
if (singleton->cache_metrics_enabled_) {
return;
}
singleton->InitializeCacheMetrics(response_cache);
singleton->cache_metrics_enabled_ = true;
}
void
Metrics::EnableGPUMetrics()
{
auto singleton = GetSingleton();
// Ensure thread-safe enabling of GPU Metrics
std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
if (singleton->gpu_metrics_enabled_) {
return;
}
if (std::getenv("TRITON_SERVER_CPU_ONLY") == nullptr) {
singleton->InitializeDcgmMetrics();
}
singleton->gpu_metrics_enabled_ = true;
}
void
Metrics::EnableCpuMetrics()
{
auto singleton = GetSingleton();
// Ensure thread-safe enabling of CPU Metrics
std::lock_guard<std::mutex> lock(singleton->metrics_enabling_);
if (singleton->cpu_metrics_enabled_) {
return;
}
singleton->InitializeCpuMetrics();
singleton->cpu_metrics_enabled_ = true;
}
void
Metrics::SetMetricsInterval(uint64_t metrics_interval_ms)
{
auto singleton = GetSingleton();
singleton->metrics_interval_ms_ = metrics_interval_ms;
}
void
Metrics::StartPollingThreadSingleton(
std::shared_ptr<RequestResponseCache> response_cache)
{
auto singleton = GetSingleton();
// Ensure thread-safe start of polling thread
std::lock_guard<std::mutex> lock(singleton->poll_thread_starting_);
if (singleton->poll_thread_started_) {
return;
}
// Start thread for polling cache/dcgm metrics
singleton->StartPollingThread(response_cache);
// Toggle flag so this function is only executed once
singleton->poll_thread_started_ = true;
}
bool
Metrics::StartPollingThread(
std::shared_ptr<RequestResponseCache> response_cache)
{
// Nothing to poll if no polling metrics enabled, don't spawn a thread
if (!cache_metrics_enabled_ && !gpu_metrics_enabled_ &&
!cpu_metrics_enabled_) {
LOG_WARNING << "No polling metrics (CPU, GPU, Cache) are enabled. Will not "
"poll for them.";
return false;
}
poll_thread_exit_.store(false);
// Start a separate thread for polling metrics at specified interval
poll_thread_.reset(new std::thread([this, response_cache] {
// Thread will update metrics indefinitely until exit flag set
while (!poll_thread_exit_.load()) {
// Sleep for metric interval
std::this_thread::sleep_for(
std::chrono::milliseconds(metrics_interval_ms_ / 2));
// Poll Response Cache metrics
if (cache_metrics_enabled_ && response_cache != nullptr) {
PollCacheMetrics(response_cache);
}
#ifdef TRITON_ENABLE_METRICS_GPU
// Poll DCGM GPU metrics
if (gpu_metrics_enabled_ &&
dcgm_metadata_.available_cuda_gpu_ids_.size() > 0) {
PollDcgmMetrics();
}
#endif // TRITON_ENABLE_METRICS_GPU
#ifdef TRITON_ENABLE_METRICS_CPU
if (cpu_metrics_enabled_) {
PollCpuMetrics();
}
#endif // TRITON_ENABLE_METRICS_CPU
}
}));
return true;
}
bool
Metrics::PollCacheMetrics(std::shared_ptr<RequestResponseCache> response_cache)
{
if (response_cache == nullptr) {
LOG_WARNING << "error polling cache metrics, cache metrics will not be "
<< "available: cache was nullptr";
return false;
}
// Update global cache metrics
cache_num_entries_global_->Set(response_cache->NumEntries());
cache_num_lookups_global_->Set(response_cache->NumLookups());
cache_num_hits_global_->Set(response_cache->NumHits());
cache_num_misses_global_->Set(response_cache->NumMisses());
cache_num_evictions_global_->Set(response_cache->NumEvictions());
cache_lookup_duration_us_global_->Set(
response_cache->TotalLookupLatencyNs() / 1000);
cache_insertion_duration_us_global_->Set(
response_cache->TotalInsertionLatencyNs() / 1000);
cache_util_global_->Set(response_cache->TotalUtilization());
return true;
}
#ifdef TRITON_ENABLE_METRICS_CPU
Status
Metrics::ParseCpuInfo(CpuInfo& info)
{
#ifdef _WIN32
return Status(
Status::Code::INTERNAL, "CPU metrics not supported on Windows.");
#else
std::ifstream ifs("/proc/stat");
if (!ifs.good()) {
return Status(Status::Code::INTERNAL, "Failed to open /proc/stat.");
}
std::string line;
// Verify first line is aggregate cpu line
std::getline(ifs, line);
if (line.rfind("cpu ", 0) == std::string::npos) {
return Status(
Status::Code::INTERNAL,
"Failed to find aggregate CPU info in /proc/stat.");
}
std::string _;
std::istringstream iss(line);
// Use _ to skip "cpu" at start of line
if (!(iss >> _ >> info)) {
return Status(
Status::Code::INTERNAL,
"Failed to parse aggregate CPU info in /proc/stat.");
}
return Status::Success;
#endif // OS
}
Status
Metrics::ParseMemInfo(MemInfo& info)
{
#ifdef _WIN32
return Status(
Status::Code::INTERNAL, "Memory metrics not supported on Windows.");
#else
std::ifstream ifs("/proc/meminfo");
if (!ifs.good()) {
return Status(Status::Code::INTERNAL, "Failed to open /proc/meminfo.");
}
std::string line;
constexpr uint64_t KB = 1024;
while (std::getline(ifs, line)) {
std::istringstream iss(line);
std::string name;
uint64_t value = 0;
if (iss >> name >> value) {
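// Strip the trailing ':' from the field name (e.g. "MemTotal:") and store
// the value, converting from the kB units reported by /proc/meminfo to bytes.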
name.pop_back();
info[name] = value * KB;
} else {
return Status(
Status::Code::INTERNAL, "Encountered error parsing /proc/meminfo.");
}
}
if (info.find("MemTotal") == info.end() ||
info.find("MemAvailable") == info.end()) {
return Status(
Status::Code::INTERNAL,
"Failed to find desired values in /proc/meminfo.");
}
if (info["MemAvailable"] > info["MemTotal"]) {
return Status(
Status::Code::INTERNAL,
"Available bytes shouldn't be greater than Total bytes");
}
// "Used" memory can be defined in many different ways. While many
// older applications consider "used = total - (free + cached)", a more
// accurate measure of available memory "MemAvailable" was added,
// so we choose "used = total - available" for a more accurate measure.
// This may change in the future if not sufficient for most use cases.
// See https://stackoverflow.com/a/35019697.
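// For example (illustrative values): MemTotal = 16 GiB and
// MemAvailable = 10 GiB yield MemUsed = 6 GiB.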
info["MemUsed"] = info["MemTotal"] - info["MemAvailable"];
return Status::Success;
#endif // OS
}
double
Metrics::CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old)
{
// Guard against counter wrap or reset by clamping negative deltas to zero
const auto wrap_sub = [](uint64_t a, uint64_t b) {
return (a > b) ? (a - b) : 0;
};
uint64_t util_diff = wrap_sub(info_new.user, info_old.user) +
wrap_sub(info_new.nice, info_old.nice) +
wrap_sub(info_new.system, info_old.system) +
wrap_sub(info_new.irq, info_old.irq) +
wrap_sub(info_new.softirq, info_old.softirq) +
wrap_sub(info_new.steal, info_old.steal);
uint64_t idle_diff = wrap_sub(info_new.idle, info_old.idle) +
wrap_sub(info_new.iowait, info_old.iowait);
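// Example (illustrative): if non-idle jiffies increased by 75 and
// idle+iowait jiffies by 25 between samples, the utilization is
// 75 / (75 + 25) = 0.75.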
double util_ratio = static_cast<double>(util_diff) / (util_diff + idle_diff);
return util_ratio;
}
#endif // TRITON_ENABLE_METRICS_CPU
bool
Metrics::PollCpuMetrics()
{
#ifndef TRITON_ENABLE_METRICS_CPU
return false;
#else
// CPU Utilization
double cpu_util = 0.0;
auto cpu_info = CpuInfo();
auto status = ParseCpuInfo(cpu_info);
if (status.IsOk()) {
cpu_util = CpuUtilization(cpu_info, last_cpu_info_);
last_cpu_info_ = cpu_info;
}
cpu_utilization_->Set(cpu_util); // [0.0, 1.0]
// RAM / Memory
double mem_total_bytes = 0.0;
double mem_used_bytes = 0.0;
auto mem_info = MemInfo();
status = ParseMemInfo(mem_info);
if (status.IsOk()) {
// MemTotal rarely changes over time, but re-reading it on every poll means
// a failure to query memory is reflected in the reported values.
mem_total_bytes = mem_info["MemTotal"];
mem_used_bytes = mem_info["MemUsed"];
}
cpu_memory_total_->Set(mem_total_bytes);
cpu_memory_used_->Set(mem_used_bytes);
return true;
#endif // TRITON_ENABLE_METRICS_CPU
}
bool
Metrics::PollDcgmMetrics()
{
#ifndef TRITON_ENABLE_METRICS_GPU
return false;
#else
if (dcgm_metadata_.available_cuda_gpu_ids_.size() == 0) {
LOG_WARNING << "error polling GPU metrics, GPU metrics will not be "
<< "available: no available gpus to poll";
return false;
}
dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1 /* wait for update*/);
for (unsigned int didx = 0;
didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) {
uint32_t cuda_id = dcgm_metadata_.available_cuda_gpu_ids_[didx];
if (dcgm_metadata_.cuda_ids_to_dcgm_ids_.count(cuda_id) <= 0) {
LOG_WARNING << "Cannot find DCGM id for CUDA id " << cuda_id;
continue;
}
uint32_t dcgm_id = dcgm_metadata_.cuda_ids_to_dcgm_ids_.at(cuda_id);
dcgmFieldValue_v1 field_values[dcgm_metadata_.field_count_];
dcgmReturn_t dcgmerr = dcgmGetLatestValuesForFields(
dcgm_metadata_.dcgm_handle_, dcgm_id, dcgm_metadata_.fields_.data(),
dcgm_metadata_.field_count_, field_values);
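// field_values[i] corresponds to dcgm_metadata_.fields_[i] as ordered in
// InitializeDcgmMetrics(): 0 = power limit, 1 = power usage,
// 2 = total energy, 3 = utilization, 4 = FB used, 5 = FB total.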
if (dcgmerr != DCGM_ST_OK) {
dcgm_metadata_.power_limit_fail_cnt_[didx]++;
dcgm_metadata_.power_usage_fail_cnt_[didx]++;
dcgm_metadata_.energy_fail_cnt_[didx]++;
dcgm_metadata_.util_fail_cnt_[didx]++;
dcgm_metadata_.mem_fail_cnt_[didx]++;
LOG_WARNING << "Unable to get field values for GPU ID " << cuda_id << ": "
<< errorString(dcgmerr);
} else {
// Power limit
if (dcgm_metadata_.power_limit_fail_cnt_[didx] <
dcgm_metadata_.fail_threshold_) {
double power_limit = field_values[0].value.dbl;
if ((field_values[0].status == DCGM_ST_OK) &&
(!DCGM_FP64_IS_BLANK(power_limit))) {
dcgm_metadata_.power_limit_fail_cnt_[didx] = 0;
} else {
dcgm_metadata_.power_limit_fail_cnt_[didx]++;
power_limit = 0;
dcgmReturn_t status = dcgmReturn_t(field_values[0].status);
LOG_WARNING << "Unable to get power limit for GPU " << cuda_id
<< ". Status:" << errorString(status)
<< ", value:" << dcgmValueToErrorMessage(power_limit);
}
gpu_power_limit_[didx]->Set(power_limit);
}
// Power usage
if (dcgm_metadata_.power_usage_fail_cnt_[didx] <
dcgm_metadata_.fail_threshold_) {
double power_usage = field_values[1].value.dbl;
if ((field_values[1].status == DCGM_ST_OK) &&
(!DCGM_FP64_IS_BLANK(power_usage))) {
dcgm_metadata_.power_usage_fail_cnt_[didx] = 0;
} else {
dcgm_metadata_.power_usage_fail_cnt_[didx]++;
power_usage = 0;
dcgmReturn_t status = dcgmReturn_t(field_values[1].status);
LOG_WARNING << "Unable to get power usage for GPU " << cuda_id
<< ". Status:" << errorString(status)
<< ", value:" << dcgmValueToErrorMessage(power_usage);
}
gpu_power_usage_[didx]->Set(power_usage);
}
// Energy Consumption
if (dcgm_metadata_.energy_fail_cnt_[didx] <
dcgm_metadata_.fail_threshold_) {
int64_t energy = field_values[2].value.i64;
if ((field_values[2].status == DCGM_ST_OK) &&
(!DCGM_INT64_IS_BLANK(energy))) {
dcgm_metadata_.energy_fail_cnt_[didx] = 0;
if (dcgm_metadata_.last_energy_[didx] == 0) {
dcgm_metadata_.last_energy_[didx] = energy;
}
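// DCGM reports total energy consumption in millijoules; accumulate the
// delta since the previous poll in joules (hence the 0.001 factor).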
gpu_energy_consumption_[didx]->Increment(
(double)(energy - dcgm_metadata_.last_energy_[didx]) * 0.001);
dcgm_metadata_.last_energy_[didx] = energy;
} else {
dcgm_metadata_.energy_fail_cnt_[didx]++;
energy = 0;
dcgmReturn_t status = dcgmReturn_t(field_values[2].status);
LOG_WARNING << "Unable to get energy consumption for "
<< "GPU " << cuda_id << ". Status:" << errorString(status)
<< ", value:" << dcgmValueToErrorMessage(energy);
}
}
// Utilization
if (dcgm_metadata_.util_fail_cnt_[didx] <
dcgm_metadata_.fail_threshold_) {
int64_t util = field_values[3].value.i64;
if ((field_values[3].status == DCGM_ST_OK) &&
(!DCGM_INT64_IS_BLANK(util))) {
dcgm_metadata_.util_fail_cnt_[didx] = 0;
} else {
dcgm_metadata_.util_fail_cnt_[didx]++;
util = 0;
dcgmReturn_t status = dcgmReturn_t(field_values[3].status);
LOG_WARNING << "Unable to get GPU utilization for GPU " << cuda_id
<< ". Status:" << errorString(status)
<< ", value:" << dcgmValueToErrorMessage(util);
}
gpu_utilization_[didx]->Set((double)util * 0.01);
}
// Memory Usage
if (dcgm_metadata_.mem_fail_cnt_[didx] < dcgm_metadata_.fail_threshold_) {
int64_t memory_used = field_values[4].value.i64;
int64_t memory_total = field_values[5].value.i64;
if ((field_values[4].status == DCGM_ST_OK) &&
(!DCGM_INT64_IS_BLANK(memory_used)) &&
(field_values[5].status == DCGM_ST_OK) &&
(!DCGM_INT64_IS_BLANK(memory_total))) {
dcgm_metadata_.mem_fail_cnt_[didx] = 0;
} else {
memory_total = 0;
memory_used = 0;
dcgm_metadata_.mem_fail_cnt_[didx]++;
dcgmReturn_t usageStatus = dcgmReturn_t(field_values[4].status);
dcgmReturn_t memoryTotalStatus = dcgmReturn_t(field_values[5].status);
LOG_WARNING << "Unable to get memory usage for GPU " << cuda_id
<< ". Memory usage status:" << errorString(usageStatus)
<< ", value:" << dcgmValueToErrorMessage(memory_used)
<< ". Memory total status:"
<< errorString(memoryTotalStatus)
<< ", value:" << dcgmValueToErrorMessage(memory_total);
}
gpu_memory_total_[didx]->Set(memory_total * 1024 * 1024); // bytes
gpu_memory_used_[didx]->Set(memory_used * 1024 * 1024); // bytes
}
}
}
return true;
#endif // TRITON_ENABLE_METRICS_GPU
}
bool
Metrics::InitializeCacheMetrics(
std::shared_ptr<RequestResponseCache> response_cache)
{
if (response_cache == nullptr) {
LOG_WARNING
<< "error initializing cache metrics, cache metrics will not be "
<< "available: cache was nullptr";
return false;
}
const std::map<std::string, std::string> cache_labels;
cache_num_entries_global_ = &cache_num_entries_family_.Add(cache_labels);
cache_num_lookups_global_ = &cache_num_lookups_family_.Add(cache_labels);
cache_num_hits_global_ = &cache_num_hits_family_.Add(cache_labels);
cache_num_misses_global_ = &cache_num_misses_family_.Add(cache_labels);
cache_num_evictions_global_ = &cache_num_evictions_family_.Add(cache_labels);
cache_lookup_duration_us_global_ =
&cache_lookup_duration_us_family_.Add(cache_labels);
cache_insertion_duration_us_global_ =
&cache_insertion_duration_us_family_.Add(cache_labels);
cache_util_global_ = &cache_util_family_.Add(cache_labels);
LOG_INFO << "Collecting Response Cache metrics";
return true;
}
bool
Metrics::InitializeCpuMetrics()
{
#ifndef TRITON_ENABLE_METRICS_CPU
return false;
#else
const std::map<std::string, std::string> cpu_labels;
cpu_utilization_ = &cpu_utilization_family_.Add(cpu_labels);
cpu_memory_total_ = &cpu_memory_total_family_.Add(cpu_labels);
cpu_memory_used_ = &cpu_memory_used_family_.Add(cpu_labels);
// Get baseline CPU info for future comparisons
last_cpu_info_ = CpuInfo();
auto status = ParseCpuInfo(last_cpu_info_);
if (!status.IsOk()) {
LOG_WARNING << "error initializing CPU metrics, CPU utilization may not "
"be available: "
<< status.Message();
return false;
}
// Verify memory metrics can be parsed
auto mem_info = MemInfo();
status = ParseMemInfo(mem_info);
if (!status.IsOk()) {
LOG_WARNING << "error initializing CPU metrics, CPU memory metrics may not "
"be available: "
<< status.Message();
return false;
}
LOG_INFO << "Collecting CPU metrics";
return true;
#endif // TRITON_ENABLE_METRICS_CPU
}
bool
Metrics::InitializeDcgmMetrics()
{
#ifndef TRITON_ENABLE_METRICS_GPU
return false;
#else
dcgmReturn_t dcgmerr = dcgmInit();
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "error initializing DCGM, GPU metrics will not be "
<< "available: " << errorString(dcgmerr);
return false;
}
if (dcgm_metadata_.standalone_) {
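// In standalone mode, connect to a DCGM host engine assumed to be
// running locally at 127.0.0.1.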
char hostIpAddress[16] = {0};
std::string ipAddress = "127.0.0.1";
strncpy(hostIpAddress, ipAddress.c_str(), 15);
dcgmerr = dcgmConnect(hostIpAddress, &dcgm_metadata_.dcgm_handle_);
} else {
dcgmerr = dcgmStartEmbedded(
DCGM_OPERATION_MODE_MANUAL, &dcgm_metadata_.dcgm_handle_);
}
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "DCGM unable to start: " << errorString(dcgmerr);
return false;
} else {
// Set this flag to signal DCGM cleanup in destructor
dcgm_metadata_.dcgm_initialized_ = true;
}
if (dcgm_metadata_.standalone_) {
dcgmerr = dcgmUpdateAllFields(dcgm_metadata_.dcgm_handle_, 1);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "DCGM unable to update all fields, GPU metrics will "
"not be available: "
<< errorString(dcgmerr);
return false;
}
}
unsigned int dcgm_gpu_ids[DCGM_MAX_NUM_DEVICES];
int dcgm_gpu_count;
dcgmerr = dcgmGetAllDevices(
dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids, &dcgm_gpu_count);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "DCGM unable to get device info and count, GPU "
"metrics will not be available: "
<< errorString(dcgmerr);
return false;
}
// Get PCI Bus ID to DCGM device Id map.
// Some devices may have problems using the DCGM API and
// these devices need to be ignored.
std::map<std::string, size_t> pci_bus_id_to_dcgm_id;
std::map<std::string, std::map<std::string, std::string> >
pci_bus_id_to_gpu_labels;
std::map<std::string, std::string> pci_bus_id_to_device_name;
dcgmDeviceAttributes_t gpu_attributes[DCGM_MAX_NUM_DEVICES];
for (int i = 0; i < dcgm_gpu_count; i++) {
gpu_attributes[i].version = dcgmDeviceAttributes_version;
dcgmerr = dcgmGetDeviceAttributes(
dcgm_metadata_.dcgm_handle_, dcgm_gpu_ids[i], &gpu_attributes[i]);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "DCGM unable to get device properties for DCGM device "
<< dcgm_gpu_ids[i]
<< ", GPU metrics will not be available for this device: "
<< errorString(dcgmerr);
} else {
std::string pciBusId = gpu_attributes[i].identifiers.pciBusId;
pci_bus_id_to_dcgm_id[pciBusId] = i;
pci_bus_id_to_device_name[pciBusId] =
std::string(gpu_attributes[i].identifiers.deviceName);
std::map<std::string, std::string> gpu_labels;
gpu_labels.insert(std::map<std::string, std::string>::value_type(
kMetricsLabelGpuUuid,
std::string(gpu_attributes[i].identifiers.uuid)));
pci_bus_id_to_gpu_labels[pciBusId] = gpu_labels;
}
}
// Get CUDA-visible PCI Bus Ids and get DCGM metrics for each CUDA-visible GPU
int cuda_gpu_count;
cudaError_t cudaerr = cudaGetDeviceCount(&cuda_gpu_count);
if (cudaerr != cudaSuccess) {
LOG_WARNING
<< "Cannot get CUDA device count, GPU metrics will not be available";
return false;
}
for (int i = 0; i < cuda_gpu_count; ++i) {
std::string pci_bus_id = "0000"; // pad 0's for uniformity
char pcibusid_str[64];
cudaerr = cudaDeviceGetPCIBusId(pcibusid_str, sizeof(pcibusid_str) - 1, i);
if (cudaerr == cudaSuccess) {
pci_bus_id.append(pcibusid_str);
if (pci_bus_id_to_dcgm_id.count(pci_bus_id) <= 0) {
LOG_INFO << "Skipping GPU:" << i
<< " since it's not CUDA enabled. This should never happen!";
continue;
}
// Filter out CUDA visible GPUs from GPUs found by DCGM
LOG_INFO << "Collecting metrics for GPU " << i << ": "
<< pci_bus_id_to_device_name[pci_bus_id];
auto& gpu_labels = pci_bus_id_to_gpu_labels[pci_bus_id];
gpu_utilization_.push_back(&gpu_utilization_family_.Add(gpu_labels));
gpu_memory_total_.push_back(&gpu_memory_total_family_.Add(gpu_labels));
gpu_memory_used_.push_back(&gpu_memory_used_family_.Add(gpu_labels));
gpu_power_usage_.push_back(&gpu_power_usage_family_.Add(gpu_labels));
gpu_power_limit_.push_back(&gpu_power_limit_family_.Add(gpu_labels));
gpu_energy_consumption_.push_back(
&gpu_energy_consumption_family_.Add(gpu_labels));
uint32_t dcgm_id = pci_bus_id_to_dcgm_id[pci_bus_id];
dcgm_metadata_.cuda_ids_to_dcgm_ids_[i] = dcgm_id;
dcgm_metadata_.available_cuda_gpu_ids_.emplace_back(i);
} else {
LOG_WARNING << "GPU metrics will not be available for device:" << i;
}
}
// create a gpu group
char groupName[] = "dcgm_group";
dcgmerr = dcgmGroupCreate(
dcgm_metadata_.dcgm_handle_, DCGM_GROUP_DEFAULT, groupName,
&dcgm_metadata_.groupId_);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "Cannot make GPU group: " << errorString(dcgmerr);
}
// Initialize tracking vectors
for (unsigned int didx = 0;
didx < dcgm_metadata_.available_cuda_gpu_ids_.size(); ++didx) {
dcgm_metadata_.power_limit_fail_cnt_.push_back(0);
dcgm_metadata_.power_usage_fail_cnt_.push_back(0);
dcgm_metadata_.energy_fail_cnt_.push_back(0);
dcgm_metadata_.util_fail_cnt_.push_back(0);
dcgm_metadata_.mem_fail_cnt_.push_back(0);
dcgm_metadata_.last_energy_.push_back(0);
}
// Number of fields for DCGM to use from fields_ below
dcgm_metadata_.field_count_ = 6;
unsigned short util_flag = dcgm_metadata_.standalone_
? DCGM_FI_PROF_GR_ENGINE_ACTIVE
: DCGM_FI_DEV_GPU_UTIL;
dcgm_metadata_.fields_ = {
DCGM_FI_DEV_POWER_MGMT_LIMIT, // power limit, watts
DCGM_FI_DEV_POWER_USAGE, // power usage, watts
DCGM_FI_DEV_TOTAL_ENERGY_CONSUMPTION, // Total energy consumption, mJ
util_flag, // util ratio, 1 = 1%
DCGM_FI_DEV_FB_USED, // Frame buffer used, MiB
DCGM_FI_DEV_FB_TOTAL,                   // Frame buffer total, MiB
};
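// NOTE: the order of fields_ above must match the field_values[] indices
// used in PollDcgmMetrics().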
char fieldName[] = "field_group";
dcgmFieldGrp_t fieldGroupId;
dcgmerr = dcgmFieldGroupCreate(
dcgm_metadata_.dcgm_handle_, dcgm_metadata_.field_count_,
dcgm_metadata_.fields_.data(), fieldName, &fieldGroupId);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "Cannot make field group: " << errorString(dcgmerr);
}
dcgmerr = dcgmWatchFields(
dcgm_metadata_.dcgm_handle_, dcgm_metadata_.groupId_, fieldGroupId,
metrics_interval_ms_ * 1000 /*update period, usec*/,
5.0 /*maxKeepAge, sec*/, 5 /*maxKeepSamples*/);
if (dcgmerr != DCGM_ST_OK) {
LOG_WARNING << "Cannot start watching fields: " << errorString(dcgmerr);
return false;
}
return true;
#endif // TRITON_ENABLE_METRICS_GPU
}
#ifdef TRITON_ENABLE_METRICS_GPU
std::string
Metrics::dcgmValueToErrorMessage(double val)
{
if (DCGM_FP64_IS_BLANK(val)) {
if (val == DCGM_FP64_BLANK) {
return "Not Specified";
} else if (val == DCGM_FP64_NOT_FOUND) {
return "Not Found";
} else if (val == DCGM_FP64_NOT_SUPPORTED) {
return "Not Supported";
} else if (val == DCGM_FP64_NOT_PERMISSIONED) {
return "Insf. Permission";
} else {
return "Unknown";
}
} else {
return std::to_string(val);
}
}
std::string
Metrics::dcgmValueToErrorMessage(int64_t val)
{
if (DCGM_INT64_IS_BLANK(val)) {
switch (val) {
case DCGM_INT64_BLANK:
return "Not Specified";
case DCGM_INT64_NOT_FOUND:
return "Not Found";
case DCGM_INT64_NOT_SUPPORTED:
return "Not Supported";
case DCGM_INT64_NOT_PERMISSIONED:
return "Insf. Permission";
default:
return "Unknown";
}
} else {
return std::to_string(val);
}
}
#endif // TRITON_ENABLE_METRICS_GPU
bool
Metrics::UUIDForCudaDevice(int cuda_device, std::string* uuid)
{
// If GPU metrics were not enabled then just silently fail since DCGM
// was never initialized and the CUDA device cannot be looked up (not
// worth doing anyway since metrics aren't being reported).
auto singleton = GetSingleton();
if (!singleton->gpu_metrics_enabled_) {
return false;
}
// If GPU metrics is not enabled just silently fail.
#ifndef TRITON_ENABLE_METRICS_GPU
return false;
#else
dcgmDeviceAttributes_t gpu_attributes;
gpu_attributes.version = dcgmDeviceAttributes_version;
dcgmReturn_t dcgmerr = dcgmGetDeviceAttributes(
singleton->dcgm_metadata_.dcgm_handle_, cuda_device, &gpu_attributes);
if (dcgmerr != DCGM_ST_OK) {
LOG_ERROR << "Unable to get device UUID: " << errorString(dcgmerr);
return false;
}
*uuid = gpu_attributes.identifiers.uuid;
return true;
#endif // TRITON_ENABLE_METRICS_GPU
}
std::shared_ptr<prometheus::Registry>
Metrics::GetRegistry()
{
auto singleton = Metrics::GetSingleton();
return singleton->registry_;
}
const std::string
Metrics::SerializedMetrics()
{
auto singleton = Metrics::GetSingleton();
return singleton->serializer_->Serialize(
singleton->registry_.get()->Collect());
}
Metrics*
Metrics::GetSingleton()
{
static Metrics singleton;
return &singleton;
}
}} // namespace triton::core
#endif // TRITON_ENABLE_METRICS
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#ifdef TRITON_ENABLE_METRICS
#include <atomic>
#include <mutex>
#include <thread>
#include "prometheus/counter.h"
#include "prometheus/gauge.h"
#include "prometheus/registry.h"
#include "prometheus/serializer.h"
#include "prometheus/text_serializer.h"
#include "response_cache.h"
#ifdef TRITON_ENABLE_METRICS_GPU
#include <dcgm_agent.h>
#endif // TRITON_ENABLE_METRICS_GPU
namespace triton { namespace core {
#ifdef TRITON_ENABLE_METRICS_CPU
using MemInfo = std::unordered_map<std::string, uint64_t>;
// References:
// - htop source: https://stackoverflow.com/a/23376195
// - Linux docs: https://www.kernel.org/doc/Documentation/filesystems/proc.txt
// guest/guestnice values are counted in user/nice so we skip parsing them
struct CpuInfo {
uint64_t user = 0; // normal processes executing in user mode
uint64_t nice = 0; // niced processes executing in user mode
uint64_t system = 0; // processes executing in kernel mode
uint64_t idle = 0; // twiddling thumbs
uint64_t iowait = 0; // waiting for I/O to complete
uint64_t irq = 0; // servicing interrupts
uint64_t softirq = 0; // servicing softirqs
uint64_t steal = 0; // involuntary wait
};
inline std::istream&
operator>>(std::istream& is, CpuInfo& info)
{
is >> info.user >> info.nice >> info.system >> info.idle >> info.iowait >>
info.irq >> info.softirq >> info.steal;
return is;
}
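// Example (illustrative): given the aggregate /proc/stat line
//   "cpu  4705 150 1120 1644002 8463 0 97 0"
// ParseCpuInfo() skips the leading "cpu" token and streams the remaining
// fields into user, nice, system, idle, iowait, irq, softirq and steal.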
#endif // TRITON_ENABLE_METRICS_CPU
#ifdef TRITON_ENABLE_METRICS_GPU
struct DcgmMetadata {
// DCGM handles for initialization and destruction
dcgmHandle_t dcgm_handle_ = 0;
dcgmGpuGrp_t groupId_ = 0;
// DCGM Flags
bool standalone_ = false;
// DCGM Fields
size_t field_count_ = 0;
std::vector<unsigned short> fields_;
// GPU Device Mapping
std::map<uint32_t, uint32_t> cuda_ids_to_dcgm_ids_;
std::vector<uint32_t> available_cuda_gpu_ids_;
// Stop attempting metrics if they fail multiple consecutive
// times for a device.
const int fail_threshold_ = 3;
// DCGM Failure Tracking
std::vector<int> power_limit_fail_cnt_;
std::vector<int> power_usage_fail_cnt_;
std::vector<int> energy_fail_cnt_;
std::vector<int> util_fail_cnt_;
std::vector<int> mem_fail_cnt_;
// DCGM Energy Tracking
std::vector<unsigned long long> last_energy_;
// Track if DCGM handle initialized successfully
bool dcgm_initialized_ = false;
};
#endif // TRITON_ENABLE_METRICS_GPU
class Metrics {
public:
// Return the hash value of the labels
static size_t HashLabels(const std::map<std::string, std::string>& labels);
// Are metrics enabled?
static bool Enabled();
// Enable reporting of metrics
static void EnableMetrics();
// Enable reporting of GPU metrics
static void EnableGPUMetrics();
// Enable reporting of CPU metrics
static void EnableCpuMetrics();
// Enable reporting of Cache metrics
static void EnableCacheMetrics(
std::shared_ptr<RequestResponseCache> response_cache);
// Start a thread for polling enabled metrics if any
static void StartPollingThreadSingleton(
std::shared_ptr<RequestResponseCache> response_cache);
// Set the time interval in milliseconds at which metrics are collected
static void SetMetricsInterval(uint64_t metrics_interval_ms);
// Get the prometheus registry
static std::shared_ptr<prometheus::Registry> GetRegistry();
// Get serialized metrics
static const std::string SerializedMetrics();
// Get the UUID for a CUDA device. Return true and initialize 'uuid'
// if a UUID is found, return false if a UUID cannot be returned.
static bool UUIDForCudaDevice(int cuda_device, std::string* uuid);
// Metric family counting successful inference requests
static prometheus::Family<prometheus::Counter>& FamilyInferenceSuccess()
{
return GetSingleton()->inf_success_family_;
}
// Metric family counting failed inference requests
static prometheus::Family<prometheus::Counter>& FamilyInferenceFailure()
{
return GetSingleton()->inf_failure_family_;
}
// Metric family counting inferences performed, where a batch-size
// 'n' inference request is counted as 'n' inferences
static prometheus::Family<prometheus::Counter>& FamilyInferenceCount()
{
return GetSingleton()->inf_count_family_;
}
// Metric family counting inference batch executions, where a dynamically
// formed batch of 'n' requests is counted as one execution
static prometheus::Family<prometheus::Counter>&
FamilyInferenceExecutionCount()
{
return GetSingleton()->inf_count_exec_family_;
}
// Metric family of cumulative inference request duration, in
// microseconds
static prometheus::Family<prometheus::Counter>&
FamilyInferenceRequestDuration()
{
return GetSingleton()->inf_request_duration_us_family_;
}
// Metric family of cumulative inference queuing duration, in
// microseconds
static prometheus::Family<prometheus::Counter>& FamilyInferenceQueueDuration()
{
return GetSingleton()->inf_queue_duration_us_family_;
}
// Metric family of cumulative inference compute durations, in
// microseconds
static prometheus::Family<prometheus::Counter>&
FamilyInferenceComputeInputDuration()
{
return GetSingleton()->inf_compute_input_duration_us_family_;
}
static prometheus::Family<prometheus::Counter>&
FamilyInferenceComputeInferDuration()
{
return GetSingleton()->inf_compute_infer_duration_us_family_;
}
static prometheus::Family<prometheus::Counter>&
FamilyInferenceComputeOutputDuration()
{
return GetSingleton()->inf_compute_output_duration_us_family_;
}
// Metric families of per-model response cache metrics
static prometheus::Family<prometheus::Counter>& FamilyCacheHitCount()
{
return GetSingleton()->cache_num_hits_model_family_;
}
static prometheus::Family<prometheus::Counter>& FamilyCacheHitLookupDuration()
{
return GetSingleton()->cache_hit_lookup_duration_us_model_family_;
}
static prometheus::Family<prometheus::Counter>& FamilyCacheMissCount()
{
return GetSingleton()->cache_num_misses_model_family_;
}
static prometheus::Family<prometheus::Counter>&
FamilyCacheMissLookupDuration()
{
return GetSingleton()->cache_miss_lookup_duration_us_model_family_;
}
static prometheus::Family<prometheus::Counter>&
FamilyCacheMissInsertionDuration()
{
return GetSingleton()->cache_miss_insertion_duration_us_model_family_;
}
private:
Metrics();
virtual ~Metrics();
static Metrics* GetSingleton();
bool InitializeDcgmMetrics();
bool InitializeCpuMetrics();
bool InitializeCacheMetrics(
std::shared_ptr<RequestResponseCache> response_cache);
bool StartPollingThread(std::shared_ptr<RequestResponseCache> response_cache);
bool PollCacheMetrics(std::shared_ptr<RequestResponseCache> response_cache);
bool PollDcgmMetrics();
bool PollCpuMetrics();
std::string dcgmValueToErrorMessage(double val);
std::string dcgmValueToErrorMessage(int64_t val);
std::shared_ptr<prometheus::Registry> registry_;
std::unique_ptr<prometheus::Serializer> serializer_;
prometheus::Family<prometheus::Counter>& inf_success_family_;
prometheus::Family<prometheus::Counter>& inf_failure_family_;
prometheus::Family<prometheus::Counter>& inf_count_family_;
prometheus::Family<prometheus::Counter>& inf_count_exec_family_;
prometheus::Family<prometheus::Counter>& inf_request_duration_us_family_;
prometheus::Family<prometheus::Counter>& inf_queue_duration_us_family_;
prometheus::Family<prometheus::Counter>&
inf_compute_input_duration_us_family_;
prometheus::Family<prometheus::Counter>&
inf_compute_infer_duration_us_family_;
prometheus::Family<prometheus::Counter>&
inf_compute_output_duration_us_family_;
// Global Response Cache metrics
prometheus::Family<prometheus::Gauge>& cache_num_entries_family_;
prometheus::Family<prometheus::Gauge>& cache_num_lookups_family_;
prometheus::Family<prometheus::Gauge>& cache_num_hits_family_;
prometheus::Family<prometheus::Gauge>& cache_num_misses_family_;
prometheus::Family<prometheus::Gauge>& cache_num_evictions_family_;
prometheus::Family<prometheus::Gauge>& cache_lookup_duration_us_family_;
prometheus::Family<prometheus::Gauge>& cache_insertion_duration_us_family_;
prometheus::Family<prometheus::Gauge>& cache_util_family_;
// Gauges for Global Response Cache metrics
prometheus::Gauge* cache_num_entries_global_;
prometheus::Gauge* cache_num_lookups_global_;
prometheus::Gauge* cache_num_hits_global_;
prometheus::Gauge* cache_num_misses_global_;
prometheus::Gauge* cache_num_evictions_global_;
prometheus::Gauge* cache_lookup_duration_us_global_;
prometheus::Gauge* cache_insertion_duration_us_global_;
prometheus::Gauge* cache_util_global_;
// Per-model Response Cache metrics
prometheus::Family<prometheus::Counter>& cache_num_hits_model_family_;
prometheus::Family<prometheus::Counter>&
cache_hit_lookup_duration_us_model_family_;
prometheus::Family<prometheus::Counter>& cache_num_misses_model_family_;
prometheus::Family<prometheus::Counter>&
cache_miss_lookup_duration_us_model_family_;
prometheus::Family<prometheus::Counter>&
cache_miss_insertion_duration_us_model_family_;
#ifdef TRITON_ENABLE_METRICS_GPU
prometheus::Family<prometheus::Gauge>& gpu_utilization_family_;
prometheus::Family<prometheus::Gauge>& gpu_memory_total_family_;
prometheus::Family<prometheus::Gauge>& gpu_memory_used_family_;
prometheus::Family<prometheus::Gauge>& gpu_power_usage_family_;
prometheus::Family<prometheus::Gauge>& gpu_power_limit_family_;
prometheus::Family<prometheus::Counter>& gpu_energy_consumption_family_;
std::vector<prometheus::Gauge*> gpu_utilization_;
std::vector<prometheus::Gauge*> gpu_memory_total_;
std::vector<prometheus::Gauge*> gpu_memory_used_;
std::vector<prometheus::Gauge*> gpu_power_usage_;
std::vector<prometheus::Gauge*> gpu_power_limit_;
std::vector<prometheus::Counter*> gpu_energy_consumption_;
DcgmMetadata dcgm_metadata_;
#endif // TRITON_ENABLE_METRICS_GPU
#ifdef TRITON_ENABLE_METRICS_CPU
// Parses "/proc/meminfo" for metrics, currently only supported on Linux.
Status ParseMemInfo(MemInfo& info);
// Parses "/proc/stat" for metrics, currently only supported on Linux.
Status ParseCpuInfo(CpuInfo& info);
// Computes CPU utilization between "info_new" and "info_old" values
double CpuUtilization(const CpuInfo& info_new, const CpuInfo& info_old);
prometheus::Family<prometheus::Gauge>& cpu_utilization_family_;
prometheus::Family<prometheus::Gauge>& cpu_memory_total_family_;
prometheus::Family<prometheus::Gauge>& cpu_memory_used_family_;
prometheus::Gauge* cpu_utilization_;
prometheus::Gauge* cpu_memory_total_;
prometheus::Gauge* cpu_memory_used_;
CpuInfo last_cpu_info_;
#endif // TRITON_ENABLE_METRICS_CPU
// Thread for polling cache/GPU/CPU metrics periodically
std::unique_ptr<std::thread> poll_thread_;
std::atomic<bool> poll_thread_exit_;
bool metrics_enabled_;
bool gpu_metrics_enabled_;
bool cpu_metrics_enabled_;
bool cache_metrics_enabled_;
bool poll_thread_started_;
std::mutex metrics_enabling_;
std::mutex poll_thread_starting_;
uint64_t metrics_interval_ms_;
};
}} // namespace triton::core
#endif // TRITON_ENABLE_METRICS
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model.h"
#include <chrono>
#include <future>
#include "constants.h"
#include "filesystem.h"
#include "infer_request.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
Status
Model::GetInput(
const std::string& name, const inference::ModelInput** input) const
{
const auto itr = input_map_.find(name);
if (itr == input_map_.end()) {
return Status(
Status::Code::INVALID_ARG,
"unexpected inference input '" + name + "' for model '" + Name() + "'");
}
*input = &itr->second;
return Status::Success;
}
Status
Model::GetOutput(
const std::string& name, const inference::ModelOutput** output) const
{
const auto itr = output_map_.find(name);
if (itr == output_map_.end()) {
return Status(
Status::Code::INVALID_ARG, "unexpected inference output '" + name +
"' for model '" + Name() + "'");
}
*output = &itr->second;
return Status::Success;
}
Status
Model::SetModelConfig(const inference::ModelConfig& config)
{
config_ = config;
set_model_config_ = true;
return Status::Success;
}
Status
Model::SetScheduler(std::unique_ptr<Scheduler> scheduler)
{
if (scheduler_ != nullptr) {
return Status(
Status::Code::INTERNAL, "Attempt to change scheduler not allowed");
}
scheduler_ = std::move(scheduler);
return Status::Success;
}
Status
Model::Init(const bool is_config_provided)
{
if (!set_model_config_ && !is_config_provided) {
return Status(
Status::Code::NOT_FOUND,
"model configuration is not provided for model '" + Name() + "'");
}
RETURN_IF_ERROR(ValidateModelConfig(config_, min_compute_capability_));
RETURN_IF_ERROR(ValidateModelIOConfig(config_));
// Initialize the input map
for (const auto& io : config_.input()) {
input_map_.insert(std::make_pair(io.name(), io));
if (!io.optional()) {
++required_input_count_;
}
}
// Initialize the output map and label provider for each output
label_provider_ = std::make_shared<LabelProvider>();
for (const auto& io : config_.output()) {
output_map_.insert(std::make_pair(io.name(), io));
if (!io.label_filename().empty()) {
const auto label_path = JoinPath({model_dir_, io.label_filename()});
RETURN_IF_ERROR(label_provider_->AddLabels(io.name(), label_path));
}
}
if (config_.has_dynamic_batching()) {
default_priority_level_ =
config_.dynamic_batching().default_priority_level();
max_priority_level_ = config_.dynamic_batching().priority_levels();
} else if (config_.has_ensemble_scheduling()) {
// For ensemble, allow any priority level to pass through
default_priority_level_ = 0;
max_priority_level_ = UINT32_MAX;
} else {
default_priority_level_ = 0;
max_priority_level_ = 0;
}
return Status::Success;
}
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "infer_stats.h"
#include "label_provider.h"
#include "model_config.pb.h"
#include "scheduler.h"
#include "status.h"
namespace triton { namespace core {
class InferenceRequest;
//
// Interface for models that handle inference requests.
//
class Model {
public:
explicit Model(
const double min_compute_capability, const std::string& model_dir,
const int64_t version, const inference::ModelConfig& config)
: config_(config), min_compute_capability_(min_compute_capability),
version_(version), required_input_count_(0), model_dir_(model_dir),
set_model_config_(false)
{
}
virtual ~Model() {}
// Get the name of model being served.
const std::string& Name() const { return config_.name(); }
// Get the version of model being served.
int64_t Version() const { return version_; }
// Get the configuration of model being served.
const inference::ModelConfig& Config() const { return config_; }
// Get the number of required inputs
size_t RequiredInputCount() const { return required_input_count_; }
// Get the stats collector for the model being served.
InferenceStatsAggregator* MutableStatsAggregator()
{
return &stats_aggregator_;
}
const InferenceStatsAggregator& StatsAggregator() const
{
return stats_aggregator_;
}
// Get the model configuration for a named input.
Status GetInput(
const std::string& name, const inference::ModelInput** input) const;
// Get the model configuration for a named output.
Status GetOutput(
const std::string& name, const inference::ModelOutput** output) const;
// Get a label provider for the model.
const std::shared_ptr<LabelProvider>& GetLabelProvider() const
{
return label_provider_;
}
// Initialize the instance for Triton core usage
Status Init(const bool is_config_provided);
// Enqueue a request for execution. If Status::Success is returned
// then the model has taken ownership of the request object and so
// 'request' will be nullptr. If non-success is returned then the
// caller still retains ownership of 'request'.
Status Enqueue(std::unique_ptr<InferenceRequest>& request)
{
return scheduler_->Enqueue(request);
}
// Return the number of in-flight inferences.
size_t InflightInferenceCount()
{
return scheduler_->InflightInferenceCount();
}
// Stop processing future requests unless they are considered in-flight.
void Stop() { scheduler_->Stop(); }
uint32_t DefaultPriorityLevel() const { return default_priority_level_; }
uint32_t MaxPriorityLevel() const { return max_priority_level_; }
protected:
// Set the configuration of the model being served.
Status SetModelConfig(const inference::ModelConfig& config);
// Explicitly set the scheduler to use for inference requests to the
// model. The scheduler can only be set once for a model.
Status SetScheduler(std::unique_ptr<Scheduler> scheduler);
// The scheduler to use for this model.
std::unique_ptr<Scheduler> scheduler_;
// Configuration of the model.
inference::ModelConfig config_;
private:
// The minimum supported CUDA compute capability.
const double min_compute_capability_;
// Version of the model.
int64_t version_;
// The stats collector for the model.
InferenceStatsAggregator stats_aggregator_;
// Label provider for this model.
std::shared_ptr<LabelProvider> label_provider_;
size_t required_input_count_;
// Map from input name to the model configuration for that input.
std::unordered_map<std::string, inference::ModelInput> input_map_;
// Map from output name to the model configuration for that output.
std::unordered_map<std::string, inference::ModelOutput> output_map_;
// Path to model
std::string model_dir_;
// The default priority level for the model.
uint32_t default_priority_level_;
// The largest priority value for the model.
uint32_t max_priority_level_;
// Whether or not model config has been set.
bool set_model_config_;
};
}} // namespace triton::core
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model_config_cuda.h"
#include <cuda_runtime_api.h>
namespace triton { namespace core {
int
GetCudaStreamPriority(
inference::ModelOptimizationPolicy::ModelPriority priority)
{
// Default priority is 0
int cuda_stream_priority = 0;
int min, max;
cudaError_t cuerr = cudaDeviceGetStreamPriorityRange(&min, &max);
if ((cuerr != cudaErrorNoDevice) && (cuerr != cudaSuccess)) {
return 0;
}
switch (priority) {
case inference::ModelOptimizationPolicy::PRIORITY_MAX:
cuda_stream_priority = max;
break;
case inference::ModelOptimizationPolicy::PRIORITY_MIN:
cuda_stream_priority = min;
break;
default:
cuda_stream_priority = 0;
break;
}
return cuda_stream_priority;
}
}} // namespace triton::core
// Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <stdint.h>
#include "model_config.pb.h"
namespace triton { namespace core {
/// Get the CUDA stream priority for a given ModelPriority.
/// \param priority The inference::ModelOptimizationPolicy::ModelPriority
/// priority.
/// \return The CUDA stream priority, or 0 if it cannot be determined.
int GetCudaStreamPriority(
inference::ModelOptimizationPolicy::ModelPriority priority);
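// Illustrative usage (hypothetical caller code, not part of this API):
//   int prio = GetCudaStreamPriority(config.optimization().priority());
//   cudaStream_t stream;
//   cudaStreamCreateWithPriority(&stream, cudaStreamNonBlocking, prio);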
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "model_config_utils.h"
#include <google/protobuf/util/json_util.h>
#include <deque>
#include <mutex>
#include <set>
#include "constants.h"
#include "cuda_utils.h"
#include "filesystem.h"
#include "triton/common/logging.h"
#define TRITONJSON_STATUSTYPE triton::core::Status
#define TRITONJSON_STATUSRETURN(M) \
return triton::core::Status(triton::core::Status::Code::INTERNAL, (M))
#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success
#include "triton/common/triton_json.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {
namespace {
#ifdef TRITON_ENABLE_ENSEMBLE
struct EnsembleTensor {
EnsembleTensor(bool isOutput) : ready(false), isOutput(isOutput) {}
bool ready;
bool isOutput;
std::vector<EnsembleTensor*> prev_nodes;
std::vector<EnsembleTensor*> next_nodes;
};
/// Build a graph that represents the data flow in the ensemble specified in
/// the given model config. The node (ensemble tensor) in the graph can be
/// looked up using its name as key.
/// \param config The model configuration that specifies the
/// ensemble_scheduling field.
/// \param keyed_ensemble_graph Returns the ensemble graph.
/// \return The error status. A non-OK status indicates the build fails because
/// the ensemble configuration is not valid.
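/// Example (illustrative): for a two-step ensemble where step 0 maps
/// ensemble input "IN" to its model input and produces tensor "MID", and
/// step 1 consumes "MID" and produces ensemble output "OUT", the resulting
/// graph links IN -> MID -> OUT, with MID and OUT marked as step outputs.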
Status
BuildEnsembleGraph(
const inference::ModelConfig& config,
std::unordered_map<std::string, EnsembleTensor>& keyed_ensemble_graph)
{
keyed_ensemble_graph.clear();
size_t step_idx = 0;
for (const auto& element : config.ensemble_scheduling().step()) {
if (element.model_name().empty()) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'model_name' in step " + std::to_string(step_idx) +
" of ensemble '" + config.name() + "'");
}
if (element.input_map().size() == 0) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'input_map' in step " + std::to_string(step_idx) +
" of ensemble '" + config.name() + "'");
}
if (element.output_map().size() == 0) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'output_map' in step " + std::to_string(step_idx) +
" of ensemble '" + config.name() + "'");
}
// Link ensemble tensors
std::vector<EnsembleTensor*> tensor_as_output;
for (const auto& output_map : element.output_map()) {
auto it = keyed_ensemble_graph.find(output_map.second);
if (it != keyed_ensemble_graph.end()) {
if (it->second.isOutput) {
return Status(
Status::Code::INVALID_ARG,
"ensemble tensor '" + it->first +
"' can appear in an output map only once for ensemble '" +
config.name() + "' step " + std::to_string(step_idx));
} else {
it->second.isOutput = true;
}
} else {
it = keyed_ensemble_graph
.emplace(
std::make_pair(output_map.second, EnsembleTensor(true)))
.first;
}
tensor_as_output.push_back(&(it->second));
}
std::set<std::string> model_inputs;
for (const auto& input_map : element.input_map()) {
if (model_inputs.find(input_map.first) != model_inputs.end()) {
return Status(
Status::Code::INVALID_ARG,
"input '" + input_map.first + "' in model '" +
element.model_name() +
"' is mapped to multiple ensemble tensors for ensemble '" +
config.name() + "' step " + std::to_string(step_idx));
} else {
model_inputs.emplace(input_map.first);
}
auto it = keyed_ensemble_graph.find(input_map.second);
if (it == keyed_ensemble_graph.end()) {
it = keyed_ensemble_graph
.emplace(
std::make_pair(input_map.second, EnsembleTensor(false)))
.first;
}
for (auto output : tensor_as_output) {
output->prev_nodes.push_back(&(it->second));
it->second.next_nodes.push_back(output);
}
}
step_idx++;
}
return Status::Success;
}
Status
ValidateEnsembleSchedulingConfig(const inference::ModelConfig& config)
{
if (config.platform() != kEnsemblePlatform) {
return Status(
Status::Code::INVALID_ARG,
"ensemble scheduling cannot be set for model '" + config.name() +
"' whose platform is not " + kEnsemblePlatform);
}
if (config.instance_group().size() != 0) {
return Status(
Status::Code::INVALID_ARG,
"instance group should not be specified for ensemble '" +
config.name() + "'");
}
if (config.has_optimization()) {
return Status(
Status::Code::INVALID_ARG,
"optimization should not be specified for ensemble '" + config.name() +
"'");
}
if (config.model_warmup_size() != 0) {
return Status(
Status::Code::INVALID_ARG,
"model_warmup can not be specified for ensemble '" + config.name() +
"'");
}
// Make sure step is not empty and all fields are set
if (config.ensemble_scheduling().step_size() == 0) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'step' for ensemble '" + config.name() + "'");
}
std::unordered_map<std::string, EnsembleTensor> tensors;
RETURN_IF_ERROR(BuildEnsembleGraph(config, tensors));
// check data flow
std::deque<EnsembleTensor*> ready_queue;
for (const auto& input : config.input()) {
auto it = tensors.find(input.name());
if (it == tensors.end()) {
return Status(
Status::Code::INVALID_ARG, "ensemble input '" + input.name() +
"' for ensemble " + config.name() +
"' is not used");
}
it->second.ready = true;
ready_queue.push_back(&(it->second));
}
while (!ready_queue.empty()) {
auto& ready_node = ready_queue.front();
for (auto& next_node : ready_node->next_nodes) {
if (next_node->ready) {
continue;
}
bool next_node_ready = true;
for (auto& prev_node : next_node->prev_nodes) {
if (!prev_node->ready) {
next_node_ready = false;
break;
}
}
next_node->ready = next_node_ready;
if (next_node_ready) {
ready_queue.push_back(next_node);
}
}
ready_queue.pop_front();
}
std::set<std::string> outputs;
for (const auto& output : config.output()) {
auto it = tensors.find(output.name());
if (it == tensors.end()) {
return Status(
Status::Code::INVALID_ARG, "ensemble output '" + output.name() +
"' for ensemble " + config.name() +
"' is not used");
}
if (!it->second.ready) {
return Status(
Status::Code::INVALID_ARG, "output '" + output.name() +
"' for ensemble '" + config.name() +
"' is not written");
} else {
outputs.insert(it->first);
}
}
// Check redundant ensemble tensors
for (const auto& tensor : tensors) {
// skip ensemble outputs as they have been checked and can have no
// next nodes
if (outputs.find(tensor.first) != outputs.end()) {
continue;
}
if (!tensor.second.ready || (tensor.second.next_nodes.size() == 0)) {
return Status(
Status::Code::INVALID_ARG, "ensemble tensor '" + tensor.first +
"' is unused in ensemble '" +
config.name() + "'");
}
}
return Status::Success;
}
#endif // TRITON_ENABLE_ENSEMBLE
template <class ModelIO>
Status
ValidateIOShape(
const ModelIO& io, int32_t max_batch_size,
const std::string& message_prefix = "")
{
if (io.name().empty()) {
return Status(
Status::Code::INVALID_ARG, message_prefix + "must specify 'name'");
}
if (io.data_type() == inference::DataType::TYPE_INVALID) {
return Status(
Status::Code::INVALID_ARG, "model output must specify 'data_type'");
}
if (io.dims_size() == 0) {
return Status(
Status::Code::INVALID_ARG, message_prefix + "must specify 'dims'");
}
// If the configuration is non-batching, then no input or output
// reshape can be empty as that would mean that input or output was
// always empty (no data).
if (io.has_reshape() && (io.reshape().shape_size() == 0) &&
(max_batch_size == 0)) {
return Status(
Status::Code::INVALID_ARG,
message_prefix +
"cannot have empty reshape for non-batching model as scalar "
"tensors are not supported");
}
for (auto dim : io.dims()) {
// Dimension cannot be 0.
if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) {
return Status(
Status::Code::INVALID_ARG,
message_prefix + "dimension must be integer >= 1, or " +
std::to_string(triton::common::WILDCARD_DIM) +
" to indicate a variable-size dimension");
}
}
if (io.has_reshape()) {
// Zeros are not allowed in reshape.
for (auto dim : io.reshape().shape()) {
if ((dim < 1) && (dim != triton::common::WILDCARD_DIM)) {
return Status(
Status::Code::INVALID_ARG,
message_prefix + "reshape dimensions must be integer >= 1, or " +
std::to_string(triton::common::WILDCARD_DIM) +
" to indicate a variable-size dimension");
}
}
const int64_t dims_size = triton::common::GetElementCount(io.dims());
const int64_t reshape_size =
triton::common::GetElementCount(io.reshape().shape());
// dims and reshape must both have same element count
// or both have variable-size dimension.
// Special case for empty reshape... expect dims to have element
// count of 1.
if ((dims_size != reshape_size) &&
((reshape_size != 0) || (dims_size != 1))) {
return Status(
Status::Code::INVALID_ARG,
message_prefix + "has different size for dims and reshape");
}
// If the shape contains a variable-size dimension, compare whether each
// pair of chunks separated by the variable-size dimension has the same
// element count. For instance, reshaping from [2, 4, -1, 6] to
// [8, -1, 1, 6] is valid since 2 * 4 = 8 and 6 = 1 * 6.
if (dims_size == -1) {
std::vector<int64_t> dim_element_cnts;
std::vector<int64_t> reshape_element_cnts;
int64_t current_cnt = 1;
for (const auto& dim : io.dims()) {
if (dim != -1) {
current_cnt *= dim;
} else {
dim_element_cnts.push_back(current_cnt);
current_cnt = 1;
}
}
dim_element_cnts.push_back(current_cnt);
current_cnt = 1;
for (const auto& dim : io.reshape().shape()) {
if (dim != -1) {
current_cnt *= dim;
} else {
reshape_element_cnts.push_back(current_cnt);
current_cnt = 1;
}
}
reshape_element_cnts.push_back(current_cnt);
if (dim_element_cnts.size() != reshape_element_cnts.size()) {
return Status(
Status::Code::INVALID_ARG,
message_prefix +
"has different number of variable-size dimensions for dims "
"and reshape");
}
for (size_t idx = 0; idx < dim_element_cnts.size(); idx++) {
if (dim_element_cnts[idx] != reshape_element_cnts[idx]) {
return Status(
Status::Code::INVALID_ARG,
message_prefix + "has different size for dims and reshape");
}
}
}
}
return Status::Success;
}
} // namespace
Status
GetModelVersionFromPath(const std::string& path, int64_t* version)
{
auto version_dir = BaseName(path);
// Determine the version from the last segment of 'path'
try {
*version = std::atoll(version_dir.c_str());
}
catch (...) {
return Status(
Status::Code::INTERNAL,
"unable to determine model version from " + path);
}
return Status::Success;
}
Status
GetBooleanSequenceControlProperties(
const inference::ModelSequenceBatching& batcher,
const std::string& model_name,
const inference::ModelSequenceBatching::Control::Kind control_kind,
const bool required, std::string* tensor_name,
inference::DataType* tensor_datatype, float* fp32_false_value,
float* fp32_true_value, int32_t* int32_false_value,
int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value)
{
// Make sure same tensor is not configured for multiple controls
std::set<std::string> seen_tensors;
// Make sure the control kind is not mentioned multiple times.
bool seen_control = false;
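  // Illustrative config.pbtxt sketch (assumed layout, not taken from this
  // file): a boolean START control is typically declared as
  //   sequence_batching {
  //     control_input [{
  //       name: "START"
  //       control [{ kind: CONTROL_SEQUENCE_START int32_false_true: [ 0, 1 ] }]
  //     }]
  //   }
  // in which case this function returns tensor_name "START", TYPE_INT32,
  // false value 0 and true value 1.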
for (const auto& control_input : batcher.control_input()) {
if (control_input.name().empty()) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor must have a name for " +
model_name);
}
if (seen_tensors.find(control_input.name()) != seen_tensors.end()) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor '" + control_input.name() +
"' is specified for multiple control kinds for " + model_name);
}
seen_tensors.insert(control_input.name());
for (const auto& c : control_input.control()) {
if (c.kind() == control_kind) {
if (seen_control) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching specifies multiple " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" tensors for " + model_name);
}
*tensor_name = control_input.name();
seen_control = true;
// Make sure only one of int, float, or bool type is specified.
if (!((c.int32_false_true_size() != 0) ||
(c.fp32_false_true_size() != 0) ||
(c.bool_false_true_size() != 0))) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching must specify either 'int32_false_true', "
"'fp32_false_true' or 'bool_false_true' for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
} else if (
((c.int32_false_true_size() != 0) &&
(c.fp32_false_true_size() != 0)) ||
((c.int32_false_true_size() != 0) &&
(c.bool_false_true_size() != 0)) ||
((c.fp32_false_true_size() != 0) &&
(c.bool_false_true_size() != 0))) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching specifies more than one from "
"'int32_false_true', 'fp32_false_true' and 'bool_false_true' "
"for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
}
if (c.int32_false_true_size() > 0) {
if (c.int32_false_true_size() != 2) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control 'int32_false_true' must have "
"exactly 2 entries for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
}
if (tensor_datatype != nullptr) {
*tensor_datatype = inference::DataType::TYPE_INT32;
}
if (int32_false_value != nullptr) {
*int32_false_value = c.int32_false_true(0);
}
if (int32_true_value != nullptr) {
*int32_true_value = c.int32_false_true(1);
}
} else if (c.fp32_false_true_size() > 0) {
if (c.fp32_false_true_size() != 2) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control 'fp32_false_true' must have exactly "
"2 entries for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
}
if (tensor_datatype != nullptr) {
*tensor_datatype = inference::DataType::TYPE_FP32;
}
if (fp32_false_value != nullptr) {
*fp32_false_value = c.fp32_false_true(0);
}
if (fp32_true_value != nullptr) {
*fp32_true_value = c.fp32_false_true(1);
}
} else {
if (c.bool_false_true_size() != 2) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control 'bool_false_true' must have exactly "
"2 entries for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
}
if (tensor_datatype != nullptr) {
*tensor_datatype = inference::DataType::TYPE_BOOL;
}
if (bool_false_value != nullptr) {
*bool_false_value = c.bool_false_true(0);
}
if (bool_true_value != nullptr) {
*bool_true_value = c.bool_false_true(1);
}
}
}
}
}
if (!seen_control) {
if (required) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor must specify a " +
inference::ModelSequenceBatching_Control_Kind_Name(control_kind) +
" value for " + model_name);
}
tensor_name->clear();
}
return Status::Success;
}
Status
GetTypedSequenceControlProperties(
const inference::ModelSequenceBatching& batcher,
const std::string& model_name,
const inference::ModelSequenceBatching::Control::Kind control_kind,
const bool required, std::string* tensor_name,
inference::DataType* tensor_datatype)
{
// Make sure same tensor is not configured for multiple controls
std::set<std::string> seen_tensors;
// Make sure the control kind is not mentioned multiple times.
bool seen_control = false;
for (const auto& control_input : batcher.control_input()) {
if (control_input.name().empty()) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor must have a name for " +
model_name);
}
if (seen_tensors.find(control_input.name()) != seen_tensors.end()) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor '" + control_input.name() +
"' is specified for multiple control kinds for " + model_name);
}
seen_tensors.insert(control_input.name());
for (const auto& c : control_input.control()) {
if (c.kind() == control_kind) {
if (seen_control) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching specifies multiple " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" tensors for " + model_name);
}
*tensor_name = control_input.name();
if (tensor_datatype != nullptr) {
*tensor_datatype = c.data_type();
}
seen_control = true;
if ((c.int32_false_true_size() > 0) || (c.fp32_false_true_size() > 0) ||
(c.bool_false_true_size() > 0)) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching must not specify either 'int32_false_true', "
"'fp32_false_true' or 'bool_false_true' for " +
inference::ModelSequenceBatching_Control_Kind_Name(
control_kind) +
" for " + model_name);
}
}
}
}
if (!seen_control) {
if (required) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching control tensor must specify a " +
inference::ModelSequenceBatching_Control_Kind_Name(control_kind) +
" value for " + model_name);
}
tensor_name->clear();
}
return Status::Success;
}
Status
GetNormalizedModelConfig(
const std::string& model_name, const std::string& path,
const double min_compute_capability, inference::ModelConfig* config)
{
  // Server-side autofill only sets certain backend fields, and only for
  // models belonging to a limited set of backends, for backwards
  // compatibility. See the TensorRT, ONNX Runtime, OpenVINO, TensorFlow
  // and PyTorch backends.
  // Extracting detailed information is delegated to the backend
  // implementation to auto-complete.
RETURN_IF_ERROR(
AutoCompleteBackendFields(model_name, std::string(path), config));
LOG_VERBOSE(1) << "Server side auto-completed config: "
<< config->DebugString();
RETURN_IF_ERROR(NormalizeModelConfig(min_compute_capability, config));
return Status::Success;
}
Status
NormalizeModelConfig(
const double min_compute_capability, inference::ModelConfig* config)
{
// If version_policy is not specified, default to Latest 1 version.
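  // e.g. the normalized config then contains, in config.pbtxt terms:
  //   version_policy { latest { num_versions: 1 } }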
if (!config->has_version_policy()) {
inference::ModelVersionPolicy::Latest latest;
latest.set_num_versions(1);
config->mutable_version_policy()->mutable_latest()->CopyFrom(latest);
}
// If dynamic batching is specified...
if (config->has_dynamic_batching()) {
// If preferred batch size is not specified set it to
// max-batch-size.
if (config->dynamic_batching().preferred_batch_size().size() == 0) {
auto mutable_preferred_batch_size =
config->mutable_dynamic_batching()->mutable_preferred_batch_size();
if (config->max_batch_size() > 0) {
mutable_preferred_batch_size->Add(config->max_batch_size());
}
}
}
// If sequence batching is specified...
if (config->has_sequence_batching()) {
    // Set the default sequence idle timeout if not specified.
if (config->sequence_batching().max_sequence_idle_microseconds() == 0) {
config->mutable_sequence_batching()->set_max_sequence_idle_microseconds(
SEQUENCE_IDLE_DEFAULT_MICROSECONDS);
}
if (config->sequence_batching().has_oldest()) {
// If preferred batch size is not specified set it to
// max-batch-size.
if (config->sequence_batching().oldest().preferred_batch_size().size() ==
0) {
auto mutable_preferred_batch_size =
config->mutable_sequence_batching()
->mutable_oldest()
->mutable_preferred_batch_size();
if (config->max_batch_size() > 0) {
mutable_preferred_batch_size->Add(config->max_batch_size());
}
}
}
}
  // Ensemble models don't use the optimization settings, so only set the
  // pinned-memory defaults for non-ensemble models.
if (!config->has_ensemble_scheduling()) {
auto optimization = config->mutable_optimization();
if (!optimization->has_input_pinned_memory()) {
optimization->mutable_input_pinned_memory()->set_enable(true);
}
if (!optimization->has_output_pinned_memory()) {
optimization->mutable_output_pinned_memory()->set_enable(true);
}
}
return Status::Success;
}
Status
NormalizeInstanceGroup(
const double min_compute_capability,
const std::vector<inference::ModelInstanceGroup>& preferred_groups,
inference::ModelConfig* config)
{
// Instance group setting doesn't apply to ensemble
if (config->has_ensemble_scheduling()) {
return Status::Success;
}
// Creates a set of supported GPU device ids
std::set<int> supported_gpus;
#ifdef TRITON_ENABLE_GPU
// Get the total number of GPUs from the runtime library.
Status status = GetSupportedGPUs(&supported_gpus, min_compute_capability);
if (!status.IsOk()) {
return status;
}
#endif // TRITON_ENABLE_GPU
// Make sure there is at least one instance_group.
if (config->instance_group().empty()) {
inference::ModelInstanceGroup* group = config->add_instance_group();
group->set_name(config->name());
for (const auto& pg : preferred_groups) {
group->set_kind(pg.kind());
group->set_count(pg.count());
// handle preferred GPU setting differently based on kind
if (pg.kind() == inference::ModelInstanceGroup::KIND_GPU) {
// Don't use preferred group with KIND_GPU if there is no GPU.
if (supported_gpus.empty()) {
continue;
}
// If preferred group sets GPUs, limit deployment onto those that
// are also listed in supported gpus
if (!pg.gpus().empty()) {
for (const int32_t gid : pg.gpus()) {
if (supported_gpus.find(gid) != supported_gpus.end()) {
group->add_gpus(gid);
}
}
}
break;
} else if (pg.kind() == inference::ModelInstanceGroup::KIND_AUTO) {
// if AUTO, then set preferred GPU as is, to align with KIND_AUTO
// deduction specified below
for (const int32_t gid : pg.gpus()) {
group->add_gpus(gid);
}
break;
}
// Other kind should not set GPUs
break;
}
}
// Assign default name, kind and count to each instance group that
// doesn't give those values explicitly. For KIND_GPU, set GPUs to
// all available if not specified explicitly.
size_t cnt = 0;
for (auto& group : *config->mutable_instance_group()) {
// Name
if (group.name().empty()) {
group.set_name(config->name() + "_" + std::to_string(cnt));
}
cnt++;
    // For KIND_AUTO... if there are no GPUs or if any of the listed
    // GPUs are not present, then use KIND_CPU.
if (group.kind() == inference::ModelInstanceGroup::KIND_AUTO) {
if (supported_gpus.empty()) {
group.set_kind(inference::ModelInstanceGroup::KIND_CPU);
} else {
for (const int32_t gid : group.gpus()) {
if (supported_gpus.find(gid) == supported_gpus.end()) {
group.set_kind(inference::ModelInstanceGroup::KIND_CPU);
break;
}
}
}
if (group.kind() == inference::ModelInstanceGroup::KIND_AUTO) {
group.set_kind(inference::ModelInstanceGroup::KIND_GPU);
}
}
// KIND is resolved at this point
for (const auto& pg : preferred_groups) {
if (group.kind() != pg.kind()) {
continue;
}
      // Limit the GPU setting to what is specified in the preferred group;
      // if no GPU is available then skip to the next preferred group.
if ((group.kind() == inference::ModelInstanceGroup::KIND_GPU) &&
group.gpus().empty() && !pg.gpus().empty()) {
for (const int32_t gid : pg.gpus()) {
if (supported_gpus.find(gid) != supported_gpus.end()) {
group.add_gpus(gid);
}
}
if (group.gpus().empty()) {
continue;
}
}
if ((group.count() < 1) && (pg.count() > 0)) {
group.set_count(pg.count());
}
}
    // Set the Triton defaults if the fields are not set from a preferred group
// Count
if (group.count() < 1) {
RETURN_IF_ERROR(SetDefaultInstanceCount(&group, config->backend()));
}
// GPUs
if ((group.kind() == inference::ModelInstanceGroup::KIND_GPU) &&
(group.gpus().size() == 0)) {
for (auto d : supported_gpus) {
group.add_gpus(d);
}
}
}
return Status::Success;
}
Status
LocalizePythonBackendExecutionEnvironmentPath(
const std::string& model_path, inference::ModelConfig* config,
std::shared_ptr<LocalizedPath>* localized_model_dir)
{
if (config->backend() == "python") {
if (config->parameters().contains("EXECUTION_ENV_PATH")) {
// Read EXECUTION_ENV_PATH
std::string exec_env_path =
config->parameters().at("EXECUTION_ENV_PATH").string_value();
// Replace model directory variable with model_path
std::string model_dir_var = "$$TRITON_MODEL_DIRECTORY";
if (exec_env_path.substr(0, model_dir_var.size()) == model_dir_var) {
exec_env_path.replace(0, model_dir_var.size(), model_path);
}
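      // Illustration with hypothetical paths: for model_path "/models/add_sub",
      // an EXECUTION_ENV_PATH of "$$TRITON_MODEL_DIRECTORY/env.tar.gz" becomes
      // "/models/add_sub/env.tar.gz" after the substitution above.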
// Collapse any .. in the path
std::string abs_exec_env_path;
std::size_t prev_pos = exec_env_path.size();
std::size_t pos = exec_env_path.find_last_of('/', prev_pos - 1);
int skip = 0;
while (pos != std::string::npos && prev_pos > 0) {
if (!skip) {
abs_exec_env_path =
exec_env_path.substr(pos, prev_pos - pos) + abs_exec_env_path;
}
skip = skip > 0 ? skip - 1 : skip;
if (pos >= 3 && exec_env_path.substr(pos - 3, 3) == "/..") {
skip += 2;
}
prev_pos = pos;
pos = exec_env_path.find_last_of('/', prev_pos - 1);
}
abs_exec_env_path = exec_env_path.substr(0, prev_pos) + abs_exec_env_path;
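      // Illustration with hypothetical paths: "/models/m/1/../env.tar.gz"
      // collapses to "/models/m/env.tar.gz", which is inside the model
      // directory "/models/m" and therefore would not be localized below.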
// Localize iff abs_exec_env_path is outside the model directory
std::string model_path_slash =
model_path.back() == '/' ? model_path : model_path + "/";
if (abs_exec_env_path.substr(0, model_path_slash.size()) !=
model_path_slash) {
// Localize the file
std::shared_ptr<LocalizedPath> localized_exec_env_path;
RETURN_IF_ERROR(
LocalizePath(abs_exec_env_path, &localized_exec_env_path));
// Persist the localized temporary path
(*localized_model_dir)
->other_localized_path.push_back(localized_exec_env_path);
// Rewrite EXECUTION_ENV_PATH
config->mutable_parameters()
->at("EXECUTION_ENV_PATH")
.set_string_value(localized_exec_env_path->Path());
}
}
}
return Status::Success;
}
Status
SetDefaultInstanceCount(
inference::ModelInstanceGroup* group, const std::string& backend)
{
group->set_count(1);
  // Only certain backends opt into 'default_cpu_instance_count' because
  // some backends (PyTorch, OpenVINO) perform poorly or have high overhead
  // when using multiple instances.
const int default_cpu_instance_count = 2;
bool use_default_cpu_instance_count =
(backend == kTensorFlowBackend) || (backend == kOnnxRuntimeBackend);
if (group->kind() == inference::ModelInstanceGroup::KIND_CPU &&
use_default_cpu_instance_count) {
group->set_count(default_cpu_instance_count);
}
return Status::Success;
}
Status
AutoCompleteBackendFields(
const std::string& model_name, const std::string& model_path,
inference::ModelConfig* config)
{
std::set<std::string> version_dirs;
RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &version_dirs));
  // There must be at least one version directory that we can inspect to
  // attempt to determine the platform. If not, we skip filename-based
  // autofill. For now we allow multiple versions and only inspect the first
  // version directory to ensure it is valid. We can add more aggressive
  // checks later.
const bool has_version = (version_dirs.size() != 0);
const auto version_path =
has_version ? JoinPath({model_path, *(version_dirs.begin())}) : "";
std::set<std::string> version_dir_content;
if (has_version) {
RETURN_IF_ERROR(GetDirectoryContents(version_path, &version_dir_content));
}
  // If the model name is not given in the configuration, set it based
  // on the model path.
if (config->name().empty()) {
config->set_name(model_name);
}
  // Try to fill the 'backend' and 'default_model_filename' fields.
  // TensorFlow
  // For the TF backend, the platform is required.
if (config->platform().empty()) {
// Check 'backend', 'default_model_filename', and the actual directory
// to determine the platform
if (config->backend().empty() ||
(config->backend() == kTensorFlowBackend)) {
if (config->default_model_filename() == kTensorFlowSavedModelFilename) {
config->set_platform(kTensorFlowSavedModelPlatform);
} else if (
config->default_model_filename() == kTensorFlowGraphDefFilename) {
config->set_platform(kTensorFlowGraphDefPlatform);
} else if (config->default_model_filename().empty() && has_version) {
bool is_dir = false;
if (version_dir_content.find(kTensorFlowSavedModelFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kTensorFlowSavedModelFilename}),
&is_dir));
if (is_dir) {
config->set_platform(kTensorFlowSavedModelPlatform);
}
}
if (version_dir_content.find(kTensorFlowGraphDefFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kTensorFlowGraphDefFilename}), &is_dir));
if (!is_dir) {
config->set_platform(kTensorFlowGraphDefPlatform);
}
}
}
}
}
// Fill 'backend' and 'default_model_filename' if missing
if ((config->platform() == kTensorFlowSavedModelPlatform) ||
(config->platform() == kTensorFlowGraphDefPlatform)) {
if (config->backend().empty()) {
config->set_backend(kTensorFlowBackend);
}
if (config->default_model_filename().empty()) {
if (config->platform() == kTensorFlowSavedModelPlatform) {
config->set_default_model_filename(kTensorFlowSavedModelFilename);
} else {
config->set_default_model_filename(kTensorFlowGraphDefFilename);
}
}
return Status::Success;
}
// TensorRT
if (config->backend().empty()) {
if ((config->platform() == kTensorRTPlanPlatform) ||
(config->default_model_filename() == kTensorRTPlanFilename)) {
config->set_backend(kTensorRTBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
bool is_dir = false;
if (version_dir_content.find(kTensorRTPlanFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kTensorRTPlanFilename}), &is_dir));
if (!is_dir) {
config->set_backend(kTensorRTBackend);
}
}
}
}
if (config->backend() == kTensorRTBackend) {
if (config->platform().empty()) {
config->set_platform(kTensorRTPlanPlatform);
}
if (config->default_model_filename().empty()) {
config->set_default_model_filename(kTensorRTPlanFilename);
}
return Status::Success;
}
// ONNXRuntime
if (config->backend().empty()) {
if ((config->platform() == kOnnxRuntimeOnnxPlatform) ||
(config->default_model_filename() == kOnnxRuntimeOnnxFilename)) {
config->set_backend(kOnnxRuntimeBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
if (version_dir_content.find(kOnnxRuntimeOnnxFilename) !=
version_dir_content.end()) {
        // An ONNX model can be a file or, in the case of a large model, a
        // directory
config->set_backend(kOnnxRuntimeBackend);
}
}
}
if (config->backend() == kOnnxRuntimeBackend) {
if (config->platform().empty()) {
config->set_platform(kOnnxRuntimeOnnxPlatform);
}
if (config->default_model_filename().empty()) {
config->set_default_model_filename(kOnnxRuntimeOnnxFilename);
}
return Status::Success;
}
// OpenVINO
if (config->backend().empty()) {
if (config->default_model_filename() == kOpenVINORuntimeOpenVINOFilename) {
config->set_backend(kOpenVINORuntimeBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
if (version_dir_content.find(kOpenVINORuntimeOpenVINOFilename) !=
version_dir_content.end()) {
config->set_backend(kOpenVINORuntimeBackend);
}
}
}
if (config->backend() == kOpenVINORuntimeBackend) {
if (config->default_model_filename().empty()) {
config->set_default_model_filename(kOpenVINORuntimeOpenVINOFilename);
}
return Status::Success;
}
// PyTorch (TorchScript, LibTorch)
if (config->backend().empty()) {
if ((config->platform() == kPyTorchLibTorchPlatform) ||
(config->default_model_filename() == kPyTorchLibTorchFilename)) {
config->set_backend(kPyTorchBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
bool is_dir = false;
if (version_dir_content.find(kPyTorchLibTorchFilename) !=
version_dir_content.end()) {
RETURN_IF_ERROR(IsDirectory(
JoinPath({version_path, kPyTorchLibTorchFilename}), &is_dir));
if (!is_dir) {
config->set_backend(kPyTorchBackend);
}
}
}
}
if (config->backend() == kPyTorchBackend) {
if (config->platform().empty()) {
config->set_platform(kPyTorchLibTorchPlatform);
}
if (config->default_model_filename().empty()) {
config->set_default_model_filename(kPyTorchLibTorchFilename);
}
return Status::Success;
}
// Python
if (config->backend().empty()) {
if (config->default_model_filename() == kPythonFilename) {
config->set_backend(kPythonBackend);
} else if (
config->platform().empty() &&
config->default_model_filename().empty() && has_version) {
if (version_dir_content.find(kPythonFilename) !=
version_dir_content.end()) {
config->set_backend(kPythonBackend);
}
}
}
if (config->backend() == kPythonBackend) {
if (config->default_model_filename().empty()) {
config->set_default_model_filename(kPythonFilename);
}
return Status::Success;
}
// Custom Backend
// For now, only do the narrowest case, where no info is given in the config.
if (config->backend().empty() && config->platform().empty() &&
config->default_model_filename().empty()) {
LOG_VERBOSE(1) << "Could not infer supported backend, so attempting "
"autofill of custom backend.";
    // Since we load backends lazily, we let the model tell us which backend
    // to load. If the model name conforms to the required form
    // 'model.<backend_name>', we parse the backend name out of the model
    // name, i.e. 'model.identity' will set the backend to "identity".
const std::string delimiter = ".";
size_t pos = model_name.find(delimiter, 0);
if (pos == std::string::npos) {
return Status(
triton::common::Error::Code::INVALID_ARG,
("Invalid model name: Could not determine backend for model '" +
model_name +
"' with no backend in model configuration. Expected model name of "
"the form 'model.<backend_name>'."));
}
const std::string backend_name =
model_name.substr(pos + 1, std::string::npos);
config->set_backend(backend_name);
config->set_default_model_filename(
(std::string("model.") + backend_name).c_str());
return Status::Success;
}
return Status::Success;
}
Status
ValidateModelIOConfig(const inference::ModelConfig& config)
{
Status status;
for (const auto& io : config.input()) {
status = ValidateModelInput(io, config.max_batch_size(), config.platform());
if (!status.IsOk()) {
return Status(
status.StatusCode(), status.Message() + " for " + config.name());
}
}
for (const auto& io : config.output()) {
status =
ValidateModelOutput(io, config.max_batch_size(), config.platform());
if (!status.IsOk()) {
return Status(
status.StatusCode(), status.Message() + " for " + config.name());
}
}
status = ValidateBatchIO(config);
if (!status.IsOk()) {
return Status(
status.StatusCode(), status.Message() + " for " + config.name());
}
return Status::Success;
}
Status
ValidateBatchIO(const inference::ModelConfig& config)
{
std::set<std::string> input_names;
std::set<std::string> output_names;
for (const auto& io : config.input()) {
input_names.emplace(io.name());
}
for (const auto& io : config.output()) {
output_names.emplace(io.name());
}
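  // Illustrative config.pbtxt sketch (assumed layout, not taken from this
  // file): a batch input such as
  //   batch_input [{
  //     kind: BATCH_ELEMENT_COUNT
  //     data_type: TYPE_INT32
  //     source_input: "INPUT0"
  //   }]
  // must reference exactly one existing model input and use TYPE_INT32 or
  // TYPE_FP32, which is what the checks below enforce.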
for (const auto& batch_io : config.batch_input()) {
switch (batch_io.kind()) {
case inference::BatchInput::BATCH_ELEMENT_COUNT:
case inference::BatchInput::BATCH_ACCUMULATED_ELEMENT_COUNT:
case inference::BatchInput::BATCH_ACCUMULATED_ELEMENT_COUNT_WITH_ZERO:
case inference::BatchInput::BATCH_MAX_ELEMENT_COUNT_AS_SHAPE:
case inference::BatchInput::BATCH_ITEM_SHAPE:
case inference::BatchInput::BATCH_ITEM_SHAPE_FLATTEN: {
if (batch_io.source_input_size() != 1) {
return Status(
Status::Code::INVALID_ARG,
"batch input kind '" +
inference::BatchInput::Kind_Name(batch_io.kind()) +
"' expects 1 source input, got " +
std::to_string(batch_io.source_input_size()));
}
break;
}
default:
return Status(
Status::Code::INVALID_ARG,
"unknown batch input kind '" +
inference::BatchInput::Kind_Name(batch_io.kind()) + "'");
}
if ((batch_io.data_type() != inference::DataType::TYPE_INT32) &&
(batch_io.data_type() != inference::DataType::TYPE_FP32)) {
return Status(
Status::Code::INVALID_ARG,
"batch input data type must be TYPE_INT32 or TYPE_FP32");
}
for (const auto& source_name : batch_io.source_input()) {
if (input_names.find(source_name) == input_names.end()) {
return Status(
Status::Code::INVALID_ARG,
"unknown source input name '" + source_name + "'");
}
}
}
for (const auto& batch_io : config.batch_output()) {
switch (batch_io.kind()) {
case inference::BatchOutput::BATCH_SCATTER_WITH_INPUT_SHAPE: {
if (batch_io.source_input_size() != 1) {
return Status(
Status::Code::INVALID_ARG,
"batch output kind '" +
inference::BatchOutput::Kind_Name(batch_io.kind()) +
"' expects 1 source input, got " +
std::to_string(batch_io.source_input_size()));
}
break;
}
default:
return Status(
Status::Code::INVALID_ARG,
"unknown batch output kind '" +
inference::BatchOutput::Kind_Name(batch_io.kind()) + "'");
}
for (const auto& source_name : batch_io.source_input()) {
if (input_names.find(source_name) == input_names.end()) {
return Status(
Status::Code::INVALID_ARG,
"unknown source input name '" + source_name + "'");
}
}
std::set<std::string> target_names;
for (const auto& target_name : batch_io.target_name()) {
if (output_names.find(target_name) == output_names.end()) {
return Status(
Status::Code::INVALID_ARG,
"unknown target output name '" + target_name + "'");
}
if (target_names.emplace(target_name).second == false) {
return Status(
Status::Code::INVALID_ARG, "target output name '" + target_name +
"' can only be specified once");
}
}
}
return Status::Success;
}
Status
ValidateModelConfig(
const inference::ModelConfig& config, const double min_compute_capability)
{
if (config.name().empty()) {
return Status(
Status::Code::INVALID_ARG, "model configuration must specify 'name'");
}
if (config.backend().empty()) {
    // The backend must not be empty unless the platform is the ensemble
    // platform.
#ifdef TRITON_ENABLE_ENSEMBLE
if (config.platform() != kEnsemblePlatform)
#endif // TRITON_ENABLE_ENSEMBLE
return Status(
Status::Code::INVALID_ARG, "unexpected platform type '" +
config.platform() + "' for " +
config.name());
}
#ifdef TRITON_ENABLE_ENSEMBLE
else if (config.platform() == kEnsemblePlatform) {
return Status(
Status::Code::INVALID_ARG,
"Ensemble model '" + config.name() + "' must have platform type '" +
config.platform() + "' and empty backend type");
}
#endif // TRITON_ENABLE_ENSEMBLE
if (config.platform().empty() && config.backend().empty()) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'platform' or 'backend' for '" + config.name() + "'");
}
  // Ensure 'platform' and 'backend' both refer to the same known backend,
  // or are both unrecognized in the case of a user-provided backend.
if (GetBackendTypeFromPlatform(config.platform()) !=
GetBackendType(config.backend())) {
return Status(
Status::Code::INVALID_ARG,
"unexpected 'platform' and 'backend' pair, got:" + config.platform() +
", " + config.backend());
}
if (config.max_batch_size() < 0) {
return Status(
Status::Code::INVALID_ARG,
"'max_batch_size' must be non-negative value for " + config.name());
}
if (!config.has_version_policy()) {
return Status(
Status::Code::INVALID_ARG,
"must specify 'version policy' for " + config.name());
}
// If dynamic batching is specified make sure the preferred batch
// sizes are positive and don't exceed maximum batch size.
if (config.has_dynamic_batching()) {
for (const auto size : config.dynamic_batching().preferred_batch_size()) {
if (size <= 0) {
return Status(
Status::Code::INVALID_ARG,
"dynamic batching preferred size must be positive for " +
config.name());
}
if (size > config.max_batch_size()) {
return Status(
Status::Code::INVALID_ARG,
"dynamic batching preferred size must be <= max batch size for " +
config.name());
}
}
    // If a priority queue is specified, validate the priority levels and
    // the per-priority queue policies.
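    // Illustration: with priority_levels: 3, default_priority_level must be
    // in [1, 3] and every key in priority_queue_policy must also be in
    // [1, 3], which is what the checks below enforce.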
const auto priority_levels = config.dynamic_batching().priority_levels();
if (priority_levels != 0) {
if ((config.dynamic_batching().default_priority_level() == 0) ||
(config.dynamic_batching().default_priority_level() >
priority_levels)) {
return Status(
Status::Code::INVALID_ARG,
"default priority level must be in range [1, " +
std::to_string(priority_levels) + "] for " + config.name());
}
for (const auto& queue_policy :
config.dynamic_batching().priority_queue_policy()) {
if ((queue_policy.first == 0) ||
(queue_policy.first > priority_levels)) {
return Status(
Status::Code::INVALID_ARG,
"priority queue policy must have priority level in range [1, " +
std::to_string(priority_levels) + "] for " + config.name());
}
}
}
    // The preserve-ordering option conflicts with multiple priority levels
    // and with the DELAY timeout action
if (config.dynamic_batching().preserve_ordering()) {
if (priority_levels > 1) {
return Status(
Status::Code::INVALID_ARG,
"Only one priority level is allowed when 'preserve_ordering' is "
"true for " +
config.name());
}
const auto& default_policy =
config.dynamic_batching().default_queue_policy();
if ((default_policy.default_timeout_microseconds() != 0) &&
(default_policy.timeout_action() ==
inference::ModelQueuePolicy::DELAY)) {
return Status(
Status::Code::INVALID_ARG,
"Queue policy can not have DELAY as timeout action when "
"'preserve_ordering' is true for " +
config.name());
}
// Also need to check policy in 'priority_queue_policy'
// for single priority case
for (const auto& policy :
config.dynamic_batching().priority_queue_policy()) {
if ((policy.second.default_timeout_microseconds() != 0) &&
(policy.second.timeout_action() ==
inference::ModelQueuePolicy::DELAY)) {
return Status(
Status::Code::INVALID_ARG,
"Queue policy can not have DELAY as timeout action when "
"'preserve_ordering' is true for " +
config.name());
}
}
}
}
// If sequence batching is specified make sure the control is
// specified correctly.
if (config.has_sequence_batching()) {
const auto& batcher = config.sequence_batching();
// Check boolean controls...
std::string tensor_name;
RETURN_IF_ERROR(GetBooleanSequenceControlProperties(
batcher, config.name(),
inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_START,
false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr));
RETURN_IF_ERROR(GetBooleanSequenceControlProperties(
batcher, config.name(),
inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_END,
false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr));
RETURN_IF_ERROR(GetBooleanSequenceControlProperties(
batcher, config.name(),
inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_READY,
false /* required */, &tensor_name, nullptr, nullptr, nullptr, nullptr,
nullptr, nullptr, nullptr));
// Check CORRID control and make sure it is one of the allowed types.
inference::DataType tensor_datatype;
RETURN_IF_ERROR(GetTypedSequenceControlProperties(
batcher, config.name(),
inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_CORRID,
false /* required */, &tensor_name, &tensor_datatype));
if (!tensor_name.empty()) {
if ((tensor_datatype != inference::DataType::TYPE_UINT64) &&
(tensor_datatype != inference::DataType::TYPE_INT64) &&
(tensor_datatype != inference::DataType::TYPE_UINT32) &&
(tensor_datatype != inference::DataType::TYPE_INT32) &&
(tensor_datatype != inference::DataType::TYPE_STRING)) {
return Status(
Status::Code::INVALID_ARG,
"unexpected data type for control " +
inference::ModelSequenceBatching_Control_Kind_Name(
inference::ModelSequenceBatching::Control::
CONTROL_SEQUENCE_CORRID) +
" for " + config.name() +
". Allowed data types are TYPE_UINT64, TYPE_INT64, "
"TYPE_UINT32, "
"TYPE_INT32 and TYPE_STRING");
}
}
// If oldest-first strategy is enabled make sure the preferred
// batch sizes are positive and don't exceed maximum batch size.
if (config.sequence_batching().has_oldest()) {
for (const auto size :
config.sequence_batching().oldest().preferred_batch_size()) {
if (size <= 0) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching preferred batch size must be positive for " +
config.name());
}
if (size > config.max_batch_size()) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching preferred batch size must be <= max batch "
"size for " +
config.name());
}
}
}
// If direct strategy is enabled make sure the minimum slot utilization is
// in range (0.0, 1.0]
if (config.sequence_batching().has_direct()) {
if ((config.sequence_batching().direct().minimum_slot_utilization() <
0.0) ||
(config.sequence_batching().direct().minimum_slot_utilization() >
1.0)) {
return Status(
Status::Code::INVALID_ARG,
"sequence batching minimum slot utilization must be in range "
"(0.0, 1.0] for " +
config.name());
}
}
}
// If ensemble scheduling is specified, validate it. Otherwise,
// must validate platform and instance_group
if (config.has_ensemble_scheduling()) {
#ifdef TRITON_ENABLE_ENSEMBLE
RETURN_IF_ERROR(ValidateEnsembleSchedulingConfig(config));
#else
return Status(
Status::Code::INVALID_ARG, "ensemble scheduling not supported");
#endif // TRITON_ENABLE_ENSEMBLE
}
#ifdef TRITON_ENABLE_ENSEMBLE
else if (config.platform() == kEnsemblePlatform) {
return Status(
Status::Code::INVALID_ARG,
"ensemble scheduling must be set for ensemble " + config.name() +
" whose platform is " + kEnsemblePlatform);
}
#endif // TRITON_ENABLE_ENSEMBLE
// FIXME: DLIS-3916 - Response Cache does not yet support decoupled models
if (config.model_transaction_policy().decoupled() &&
config.response_cache().enable()) {
return Status(
Status::Code::INVALID_ARG,
"Response Cache does not currently support model " + config.name() +
" with 'decoupled' transaction policy. Please disable the response"
" cache.");
}
return Status::Success;
}
Status
ValidateInstanceGroup(
const inference::ModelConfig& config, const double min_compute_capability)
{
// Instance group setting doesn't apply to ensemble
if (config.has_ensemble_scheduling()) {
return Status::Success;
}
if (config.instance_group().size() == 0) {
return Status(
Status::Code::INVALID_ARG,
"must specify one or more 'instance group's for " + config.name());
}
// Make sure KIND_GPU instance group specifies at least one GPU and
// doesn't specify a non-existent GPU. Make sure non-KIND_GPU does
// not specify any GPUs.
#ifdef TRITON_ENABLE_GPU
std::set<int> supported_gpus;
Status status = GetSupportedGPUs(&supported_gpus, min_compute_capability);
if (!status.IsOk()) {
return status;
}
#endif // TRITON_ENABLE_GPU
for (const auto& group : config.instance_group()) {
if (group.kind() == inference::ModelInstanceGroup::KIND_MODEL) {
if (group.gpus().size() > 0) {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" has kind KIND_MODEL but specifies one or more GPUs");
}
} else if (group.kind() == inference::ModelInstanceGroup::KIND_GPU) {
#if !defined(TRITON_ENABLE_GPU) && !defined(TRITON_ENABLE_MALI_GPU)
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" has kind KIND_GPU but server does not support GPUs");
#elif defined(TRITON_ENABLE_GPU)
if (group.gpus().size() == 0) {
if (supported_gpus.size() == 0) {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" has kind KIND_GPU but no GPUs are available");
} else {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" has kind KIND_GPU but specifies no GPUs");
}
}
for (const int32_t gid : group.gpus()) {
if (supported_gpus.find(gid) == supported_gpus.end()) {
std::string supported_gpus_str;
for (const auto& cc : supported_gpus) {
if (!supported_gpus_str.empty()) {
supported_gpus_str += ", ";
}
supported_gpus_str += std::to_string(cc);
}
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" specifies invalid or unsupported gpu id " +
std::to_string(gid) +
". GPUs with at least the minimum required CUDA compute "
"compatibility of " +
std::to_string(min_compute_capability) +
" are: " + supported_gpus_str);
}
}
#endif // ! TRITON_ENABLE_GPU && ! TRITON_ENABLE_MALI_GPU
} else if (group.kind() == inference::ModelInstanceGroup::KIND_CPU) {
if (group.gpus().size() > 0) {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" has kind KIND_CPU but specifies one or more GPUs");
}
} else {
return Status(
Status::Code::INTERNAL, "instance group " + group.name() +
" of model " + config.name() +
" has unexpected kind KIND_AUTO");
}
if ((config.platform() != kTensorRTPlanPlatform) &&
!group.profile().empty()) {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" and platform " + config.platform() +
"specifies profile field which is only supported for "
"TensorRT models");
} else if (!group.profile().empty()) {
for (const auto& profile : group.profile()) {
int profile_index;
RETURN_IF_ERROR(GetProfileIndex(profile, &profile_index));
if (profile_index < 0) {
return Status(
Status::Code::INVALID_ARG,
"instance group " + group.name() + " of model " + config.name() +
" and platform " + config.platform() +
" specifies invalid profile " + profile +
". The field should contain the string representation of a "
"non-negative integer.");
}
}
}
}
return Status::Success;
}
Status
ValidateModelInput(
const inference::ModelInput& io, int32_t max_batch_size,
const std::string& platform)
{
RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model input "));
if (((io.format() == inference::ModelInput::FORMAT_NHWC) ||
(io.format() == inference::ModelInput::FORMAT_NCHW)) &&
(io.dims_size() != 3)) {
return Status(
Status::Code::INVALID_ARG, "model input NHWC/NCHW require 3 dims");
}
if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
return Status(
Status::Code::INVALID_ARG,
"shape tensors are only supported for TensorRT platform");
}
return Status::Success;
}
Status
CheckAllowedModelInput(
const inference::ModelInput& io, const std::set<std::string>& allowed)
{
if (allowed.find(io.name()) == allowed.end()) {
std::string astr;
for (const auto& a : allowed) {
if (!astr.empty()) {
astr.append(", ");
}
astr.append(a);
}
return Status(
Status::Code::INVALID_ARG, "unexpected inference input '" + io.name() +
"', allowed inputs are: " + astr);
}
return Status::Success;
}
Status
ValidateModelOutput(
const inference::ModelOutput& io, int32_t max_batch_size,
const std::string& platform)
{
RETURN_IF_ERROR(ValidateIOShape(io, max_batch_size, "model output "));
if ((platform != kTensorRTPlanPlatform) && io.is_shape_tensor()) {
return Status(
Status::Code::INVALID_ARG,
"shape tensors are only supported for TensorRT platform");
}
return Status::Success;
}
Status
CheckAllowedModelOutput(
const inference::ModelOutput& io, const std::set<std::string>& allowed)
{
if (allowed.find(io.name()) == allowed.end()) {
std::string astr;
for (const auto& a : allowed) {
if (!astr.empty()) {
astr.append(", ");
}
astr.append(a);
}
return Status(
Status::Code::INVALID_ARG, "unexpected inference output '" + io.name() +
"', allowed outputs are: " + astr);
}
return Status::Success;
}
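// Parse a boolean parameter value. Accepts, case-insensitively, "true"/"1"
// for true and "false"/"0" for false; e.g. (illustration)
// ParseBoolParameter("enable", "TRUE", &b) sets b to true, while any other
// value returns INVALID_ARG.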
Status
ParseBoolParameter(
const std::string& key, std::string value, bool* parsed_value)
{
std::transform(
value.begin(), value.end(), value.begin(),
[](unsigned char c) { return std::tolower(c); });
if ((value == "true") || (value == "1")) {
*parsed_value = true;
} else if ((value == "false") || (value == "0")) {
*parsed_value = false;
} else {
return Status(
Status::Code::INVALID_ARG,
"failed to convert " + key + " '" + value + "' to boolean value");
}
return Status::Success;
}
Status
ParseLongLongParameter(
const std::string& key, const std::string& value, int64_t* parsed_value)
{
try {
*parsed_value = std::stoll(value);
}
catch (const std::invalid_argument& ia) {
return Status(
Status::Code::INVALID_ARG,
"failed to convert " + key + " '" + value + "' to integral number");
}
return Status::Success;
}
Status
GetProfileIndex(const std::string& profile_name, int* profile_index)
{
if (profile_name.empty()) {
return Status(Status::Code::INVALID_ARG, "profile name must not be empty");
}
try {
*profile_index = stoi(profile_name);
}
catch (const std::invalid_argument& ia) {
return Status(
Status::Code::INVALID_ARG,
"unable to parse '" + profile_name + "': " + ia.what());
}
return Status::Success;
}
namespace {
Status
CollectInt64Fields(
google::protobuf::Message* message, const std::string& prefix,
std::set<std::string>* int64_fields)
{
const google::protobuf::Descriptor* desc = message->GetDescriptor();
const google::protobuf::Reflection* refl = message->GetReflection();
for (int i = 0; i < desc->field_count(); ++i) {
const google::protobuf::FieldDescriptor* field = desc->field(i);
const std::string fullname = prefix + "::" + field->name();
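    // e.g. the repeated int64 'dims' field of the 'input' message becomes
    // "ModelConfig::input::dims", matching the expected set in
    // ValidateModelConfigInt64() below.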
switch (field->type()) {
case google::protobuf::FieldDescriptor::TYPE_MESSAGE: {
if (field->is_repeated()) {
int rsize = refl->FieldSize(*message, field);
if (rsize == 0) {
refl->AddMessage(message, field);
}
rsize = refl->FieldSize(*message, field);
for (int r = 0; r < rsize; ++r) {
RETURN_IF_ERROR(CollectInt64Fields(
refl->MutableRepeatedMessage(message, field, r), fullname,
int64_fields));
}
} else {
RETURN_IF_ERROR(CollectInt64Fields(
refl->MutableMessage(message, field), fullname, int64_fields));
}
} break;
case google::protobuf::FieldDescriptor::TYPE_INT64:
case google::protobuf::FieldDescriptor::TYPE_UINT64:
case google::protobuf::FieldDescriptor::TYPE_SINT64:
case google::protobuf::FieldDescriptor::TYPE_FIXED64:
case google::protobuf::FieldDescriptor::TYPE_SFIXED64:
int64_fields->insert(fullname);
break;
default:
break;
}
}
return Status::Success;
}
Status
ValidateModelConfigInt64()
{
// Must initialize a dummy ModelConfig so that all fields are
// visited.
inference::ModelConfig config;
std::set<std::string> int64_fields;
RETURN_IF_ERROR(CollectInt64Fields(&config, "ModelConfig", &int64_fields));
LOG_VERBOSE(1) << "ModelConfig 64-bit fields:";
for (const auto& f : int64_fields) {
LOG_VERBOSE(1) << "\t" << f;
}
  // We expect to find exactly the following fields. If this check fails,
  // ModelConfig has added or removed a 64-bit field and we need to adjust
  // here and in ModelConfigToJson below.
std::set<std::string> expected{
"ModelConfig::input::dims",
"ModelConfig::input::reshape::shape",
"ModelConfig::output::dims",
"ModelConfig::output::reshape::shape",
"ModelConfig::version_policy::specific::versions",
"ModelConfig::dynamic_batching::max_queue_delay_microseconds",
"ModelConfig::dynamic_batching::default_queue_policy::default_timeout_"
"microseconds",
"ModelConfig::dynamic_batching::priority_queue_policy::value::default_"
"timeout_microseconds",
"ModelConfig::sequence_batching::direct::max_queue_delay_microseconds",
"ModelConfig::sequence_batching::state::dims",
"ModelConfig::sequence_batching::state::initial_state::dims",
"ModelConfig::sequence_batching::oldest::max_queue_delay_microseconds",
"ModelConfig::sequence_batching::max_sequence_idle_microseconds",
"ModelConfig::ensemble_scheduling::step::model_version",
"ModelConfig::model_warmup::inputs::value::dims",
"ModelConfig::optimization::cuda::graph_spec::input::value::dim",
"ModelConfig::optimization::cuda::graph_spec::graph_lower_bound::input::"
"value::dim",
"ModelConfig::instance_group::secondary_devices::device_id"};
if (int64_fields != expected) {
return Status(
Status::Code::INTERNAL, "ModelConfig 64-bit field needs update");
}
return Status::Success;
}
Status
FixInt(
triton::common::TritonJson::Value& document,
triton::common::TritonJson::Value& io, const std::string& name)
{
triton::common::TritonJson::Value str_value;
if (!io.Find(name.c_str(), &str_value)) {
return Status::Success;
}
std::string str;
RETURN_IF_ERROR(str_value.AsString(&str));
int64_t d;
try {
d = std::atoll(str.c_str());
}
catch (...) {
return Status(
Status::Code::INTERNAL,
(std::string("unable to convert '") + str + "' to integer"));
}
str_value.SetInt(d);
return Status::Success;
}
Status
FixIntArray(
triton::common::TritonJson::Value& document,
triton::common::TritonJson::Value& io, const std::string& name)
{
triton::common::TritonJson::Value fixed_shape_array(
document, triton::common::TritonJson::ValueType::ARRAY);
if (!io.Find(name.c_str())) {
return Status::Success;
}
triton::common::TritonJson::Value shape_array;
RETURN_IF_ERROR(io.MemberAsArray(name.c_str(), &shape_array));
for (size_t i = 0; i < shape_array.ArraySize(); ++i) {
std::string str;
RETURN_IF_ERROR(shape_array.IndexAsString(i, &str));
int64_t d;
try {
d = std::atoll(str.c_str());
}
catch (...) {
return Status(
Status::Code::INTERNAL,
(std::string("unable to convert '") + str + "' to integer"));
}
RETURN_IF_ERROR(fixed_shape_array.AppendInt(d));
}
shape_array.Swap(fixed_shape_array);
fixed_shape_array.Release();
return Status::Success;
}
Status
FixObjectArray(
triton::common::TritonJson::Value& document,
triton::common::TritonJson::Value& arr, const std::string& name)
{
for (size_t i = 0; i < arr.ArraySize(); ++i) {
triton::common::TritonJson::Value obj;
RETURN_IF_ERROR(arr.IndexAsObject(i, &obj));
RETURN_IF_ERROR(FixInt(document, obj, name));
}
return Status::Success;
}
} // namespace
Status
ModelConfigToJson(
const inference::ModelConfig& config, const uint32_t config_version,
std::string* json_str)
{
  // Currently only 'config_version' 1 is supported, which is the json
  // representation of the ModelConfig protobuf with the int64 fields
  // fixed to be actual numbers instead of the string madness done by
  // protobuf.
if (config_version != 1) {
return Status(
Status::Code::INVALID_ARG,
std::string("model configuration version ") +
std::to_string(config_version) +
" not supported, supported versions are: 1");
}
  // The config will have a byte size of 0 if all fields have default values,
  // in other words the config is empty.
if (config.ByteSizeLong() == 0) {
json_str->clear();
return Status::Success;
}
std::string config_json_str;
::google::protobuf::util::JsonPrintOptions options;
options.preserve_proto_field_names = true;
options.always_print_primitive_fields = true;
::google::protobuf::util::MessageToJsonString(
config, &config_json_str, options);
  // We need to verify that every 64-bit field in the ModelConfig
  // protobuf is being handled. We hardcode the known fields and check
  // just once to make sure everything has been handled. We could have
  // this check in a separately compiled CI test but it is convenient to
  // keep it here close to the code below that actually fixes the 64-bit
  // fields.
{
static std::once_flag fonce;
Status status = Status::Success;
std::call_once(fonce, [&status] { status = ValidateModelConfigInt64(); });
RETURN_IF_ERROR(status);
}
// In the json produced by protobuf, int64 and uint64 values are
// represented as strings. Protobuf doesn't provide an option to
// disable this (sigh) so we need to fix it up here as we want the
// json representation of the config to be reasonable json...
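  // e.g. (illustrative values) protobuf emits "dims": ["-1", "3", "224", "224"]
  // and the fix-ups below rewrite it to "dims": [-1, 3, 224, 224].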
triton::common::TritonJson::Value config_json;
config_json.Parse(config_json_str);
// Fix input::dims, input::reshape::shape, output::dims,
// output::reshape::shape
for (std::string name : {"input", "output"}) {
triton::common::TritonJson::Value ios;
RETURN_IF_ERROR(config_json.MemberAsArray(name.c_str(), &ios));
for (size_t i = 0; i < ios.ArraySize(); ++i) {
triton::common::TritonJson::Value io;
RETURN_IF_ERROR(ios.IndexAsObject(i, &io));
RETURN_IF_ERROR(FixIntArray(config_json, io, "dims"));
triton::common::TritonJson::Value reshape;
if (io.Find("reshape", &reshape)) {
RETURN_IF_ERROR(FixIntArray(config_json, reshape, "shape"));
}
}
}
// Fix version_policy::specific::versions
{
triton::common::TritonJson::Value vp;
if (config_json.Find("version_policy", &vp)) {
triton::common::TritonJson::Value specific;
if (vp.Find("specific", &specific)) {
RETURN_IF_ERROR(FixIntArray(config_json, specific, "versions"));
}
}
}
// Fix dynamic_batching::max_queue_delay_microseconds,
// dynamic_batching::default_queue_policy::default_timeout_microseconds,
// dynamic_batching::priority_queue_policy::value::default_timeout_microseconds
{
triton::common::TritonJson::Value db;
if (config_json.Find("dynamic_batching", &db)) {
RETURN_IF_ERROR(FixInt(config_json, db, "max_queue_delay_microseconds"));
triton::common::TritonJson::Value dqp;
if (db.Find("default_queue_policy", &dqp)) {
RETURN_IF_ERROR(
FixInt(config_json, dqp, "default_timeout_microseconds"));
}
triton::common::TritonJson::Value pqp;
if (db.Find("priority_queue_policy", &pqp)) {
// Iterate over each member in 'pqp' and fix...
std::vector<std::string> members;
RETURN_IF_ERROR(pqp.Members(&members));
for (const auto& m : members) {
triton::common::TritonJson::Value el;
RETURN_IF_ERROR(pqp.MemberAsObject(m.c_str(), &el));
RETURN_IF_ERROR(
FixInt(config_json, el, "default_timeout_microseconds"));
}
}
}
}
// Fix sequence_batching::oldest::max_queue_delay_microseconds,
// sequence_batching::direct::max_queue_delay_microseconds,
// sequence_batching::max_sequence_idle_microseconds
{
triton::common::TritonJson::Value sb;
if (config_json.Find("sequence_batching", &sb)) {
RETURN_IF_ERROR(
FixInt(config_json, sb, "max_sequence_idle_microseconds"));
triton::common::TritonJson::Value oldest;
if (sb.Find("oldest", &oldest)) {
RETURN_IF_ERROR(
FixInt(config_json, oldest, "max_queue_delay_microseconds"));
}
triton::common::TritonJson::Value direct;
if (sb.Find("direct", &direct)) {
RETURN_IF_ERROR(
FixInt(config_json, direct, "max_queue_delay_microseconds"));
}
triton::common::TritonJson::Value states;
if (sb.Find("state", &states)) {
for (size_t i = 0; i < states.ArraySize(); ++i) {
triton::common::TritonJson::Value state;
RETURN_IF_ERROR(states.IndexAsObject(i, &state));
RETURN_IF_ERROR(FixIntArray(config_json, state, "dims"));
triton::common::TritonJson::Value initial_state;
if (sb.Find("initial_state", &initial_state)) {
RETURN_IF_ERROR(FixIntArray(config_json, initial_state, "dims"));
}
}
}
}
}
// Fix ensemble_scheduling::step::model_version.
{
triton::common::TritonJson::Value ens;
if (config_json.Find("ensemble_scheduling", &ens)) {
triton::common::TritonJson::Value step;
if (ens.Find("step", &step)) {
RETURN_IF_ERROR(FixObjectArray(config_json, step, "model_version"));
}
}
}
// Fix model_warmup::inputs::value::dims.
{
triton::common::TritonJson::Value warmups;
if (config_json.Find("model_warmup", &warmups)) {
for (size_t i = 0; i < warmups.ArraySize(); ++i) {
triton::common::TritonJson::Value warmup;
RETURN_IF_ERROR(warmups.IndexAsObject(i, &warmup));
triton::common::TritonJson::Value inputs;
if (warmup.Find("inputs", &inputs)) {
std::vector<std::string> members;
RETURN_IF_ERROR(inputs.Members(&members));
for (const auto& m : members) {
triton::common::TritonJson::Value input;
RETURN_IF_ERROR(inputs.MemberAsObject(m.c_str(), &input));
RETURN_IF_ERROR(FixIntArray(config_json, input, "dims"));
}
}
}
}
}
  // Convert the fixed json back to a string...
triton::common::TritonJson::WriteBuffer buffer;
RETURN_IF_ERROR(config_json.Write(&buffer));
*json_str = std::move(buffer.MutableContents());
return Status::Success;
}
Status
JsonToModelConfig(
const std::string& json_config, const uint32_t config_version,
inference::ModelConfig* protobuf_config)
{
  // Currently only 'config_version' 1 is supported, in which the json
  // representation of the ModelConfig protobuf matches the representation
  // produced by ModelConfigToJson().
if (config_version != 1) {
return Status(
Status::Code::INVALID_ARG,
std::string("model configuration version ") +
std::to_string(config_version) +
" not supported, supported versions are: 1");
}
::google::protobuf::util::JsonParseOptions options;
options.case_insensitive_enum_parsing = true;
options.ignore_unknown_fields = false;
auto err = ::google::protobuf::util::JsonStringToMessage(
json_config, protobuf_config, options);
if (!err.ok()) {
return Status(Status::Code::INVALID_ARG, std::string(err.message()));
}
return Status::Success;
}
BackendType
GetBackendTypeFromPlatform(const std::string& platform_name)
{
if ((platform_name == kTensorFlowGraphDefPlatform) ||
(platform_name == kTensorFlowSavedModelPlatform)) {
return BackendType::BACKEND_TYPE_TENSORFLOW;
}
if (platform_name == kTensorRTPlanPlatform) {
return BackendType::BACKEND_TYPE_TENSORRT;
}
if (platform_name == kOnnxRuntimeOnnxPlatform) {
return BackendType::BACKEND_TYPE_ONNXRUNTIME;
}
if (platform_name == kPyTorchLibTorchPlatform) {
return BackendType::BACKEND_TYPE_PYTORCH;
}
return BackendType::BACKEND_TYPE_UNKNOWN;
}
/// Get the BackendType value for a backend name.
/// \param backend_name The backend name.
/// \return The BackendType, or BackendType::BACKEND_TYPE_UNKNOWN if the
/// backend name is not recognized.
BackendType
GetBackendType(const std::string& backend_name)
{
if (backend_name == kTensorFlowBackend) {
return BackendType::BACKEND_TYPE_TENSORFLOW;
}
if (backend_name == kTensorRTBackend) {
return BackendType::BACKEND_TYPE_TENSORRT;
}
if (backend_name == kOnnxRuntimeBackend) {
return BackendType::BACKEND_TYPE_ONNXRUNTIME;
}
if (backend_name == kPyTorchBackend) {
return BackendType::BACKEND_TYPE_PYTORCH;
}
return BackendType::BACKEND_TYPE_UNKNOWN;
}
TRITONSERVER_DataType
DataTypeToTriton(const inference::DataType dtype)
{
switch (dtype) {
case inference::DataType::TYPE_BOOL:
return TRITONSERVER_TYPE_BOOL;
case inference::DataType::TYPE_UINT8:
return TRITONSERVER_TYPE_UINT8;
case inference::DataType::TYPE_UINT16:
return TRITONSERVER_TYPE_UINT16;
case inference::DataType::TYPE_UINT32:
return TRITONSERVER_TYPE_UINT32;
case inference::DataType::TYPE_UINT64:
return TRITONSERVER_TYPE_UINT64;
case inference::DataType::TYPE_INT8:
return TRITONSERVER_TYPE_INT8;
case inference::DataType::TYPE_INT16:
return TRITONSERVER_TYPE_INT16;
case inference::DataType::TYPE_INT32:
return TRITONSERVER_TYPE_INT32;
case inference::DataType::TYPE_INT64:
return TRITONSERVER_TYPE_INT64;
case inference::DataType::TYPE_FP16:
return TRITONSERVER_TYPE_FP16;
case inference::DataType::TYPE_FP32:
return TRITONSERVER_TYPE_FP32;
case inference::DataType::TYPE_FP64:
return TRITONSERVER_TYPE_FP64;
case inference::DataType::TYPE_STRING:
return TRITONSERVER_TYPE_BYTES;
case inference::DataType::TYPE_BF16:
return TRITONSERVER_TYPE_BF16;
default:
break;
}
return TRITONSERVER_TYPE_INVALID;
}
inference::DataType
TritonToDataType(const TRITONSERVER_DataType dtype)
{
switch (dtype) {
case TRITONSERVER_TYPE_BOOL:
return inference::DataType::TYPE_BOOL;
case TRITONSERVER_TYPE_UINT8:
return inference::DataType::TYPE_UINT8;
case TRITONSERVER_TYPE_UINT16:
return inference::DataType::TYPE_UINT16;
case TRITONSERVER_TYPE_UINT32:
return inference::DataType::TYPE_UINT32;
case TRITONSERVER_TYPE_UINT64:
return inference::DataType::TYPE_UINT64;
case TRITONSERVER_TYPE_INT8:
return inference::DataType::TYPE_INT8;
case TRITONSERVER_TYPE_INT16:
return inference::DataType::TYPE_INT16;
case TRITONSERVER_TYPE_INT32:
return inference::DataType::TYPE_INT32;
case TRITONSERVER_TYPE_INT64:
return inference::DataType::TYPE_INT64;
case TRITONSERVER_TYPE_FP16:
return inference::DataType::TYPE_FP16;
case TRITONSERVER_TYPE_FP32:
return inference::DataType::TYPE_FP32;
case TRITONSERVER_TYPE_FP64:
return inference::DataType::TYPE_FP64;
case TRITONSERVER_TYPE_BYTES:
return inference::DataType::TYPE_STRING;
case TRITONSERVER_TYPE_BF16:
return inference::DataType::TYPE_BF16;
default:
break;
}
return inference::DataType::TYPE_INVALID;
}
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "model_config.pb.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
#include "filesystem.h"
namespace triton { namespace core {
/// Enumeration for the different backend types.
enum BackendType {
BACKEND_TYPE_UNKNOWN = 0,
BACKEND_TYPE_TENSORRT = 1,
BACKEND_TYPE_TENSORFLOW = 2,
BACKEND_TYPE_ONNXRUNTIME = 3,
BACKEND_TYPE_PYTORCH = 4
};
/// Get the version of a model from the path containing the model
/// definition file.
/// \param path The path to the model definition file.
/// \param version Returns the version.
/// \return The error status.
Status GetModelVersionFromPath(const std::string& path, int64_t* version);
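// Illustrative usage sketch (not part of the original source); the
// path below is hypothetical:
//
//   int64_t version;
//   Status status =
//       GetModelVersionFromPath("/models/densenet_onnx/3", &version);
//   // On success 'version' would be 3, taken from the final path component.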
/// Get the tensor name, false value, and true value for a boolean
/// sequence batcher control kind. If 'required' is true then a tensor
/// must be found for the control. If 'required' is false, return
/// 'tensor_name' as an empty string if the control is not mapped to
/// any tensor.
Status GetBooleanSequenceControlProperties(
const inference::ModelSequenceBatching& batcher,
const std::string& model_name,
const inference::ModelSequenceBatching::Control::Kind control_kind,
const bool required, std::string* tensor_name,
inference::DataType* tensor_datatype, float* fp32_false_value,
float* fp32_true_value, int32_t* int32_false_value,
int32_t* int32_true_value, bool* bool_false_value, bool* bool_true_value);
/// Get the tensor name and datatype for a non-boolean sequence
/// batcher control kind. If 'required' is true then a tensor must be
/// found for the control. If 'required' is false, return
/// 'tensor_name' as an empty string if the control is not mapped to
/// any tensor. 'tensor_datatype' returns the required datatype for
/// the control.
Status GetTypedSequenceControlProperties(
const inference::ModelSequenceBatching& batcher,
const std::string& model_name,
const inference::ModelSequenceBatching::Control::Kind control_kind,
const bool required, std::string* tensor_name,
inference::DataType* tensor_datatype);
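// Illustrative usage sketch (not part of the original source),
// assuming 'batching' is the sequence_batching section of a loaded
// model configuration:
//
//   std::string tensor_name;
//   inference::DataType dtype;
//   Status status = GetTypedSequenceControlProperties(
//       batching, "my_model",
//       inference::ModelSequenceBatching::Control::CONTROL_SEQUENCE_CORRID,
//       false /* required */, &tensor_name, &dtype);
//   // If the control is not mapped to any tensor, 'tensor_name' is "".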
/// Read a ModelConfig and normalize it as expected by model backends.
/// \param model_name The name of the model.
/// \param path The full path to the directory containing the
/// model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config Returns the normalized model configuration.
/// \return The error status.
Status GetNormalizedModelConfig(
const std::string& model_name, const std::string& path,
const double min_compute_capability, inference::ModelConfig* config);
/// Auto-complete backend-related fields (platform, backend, and default
/// model filename) if they are not set. Note that only backends
/// recognized by Triton will be checked.
/// \param model_name The name of the model.
/// \param model_path The full-path to the directory containing the
/// model configuration.
/// \param config Returns the auto-completed model configuration.
/// \return The error status.
Status AutoCompleteBackendFields(
const std::string& model_name, const std::string& model_path,
inference::ModelConfig* config);
/// Detects and adds missing fields in the model configuration.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config The model configuration
/// \return The error status
Status NormalizeModelConfig(
const double min_compute_capability, inference::ModelConfig* config);
/// [FIXME] better formalize config normalization / validation
/// Detects and adds missing fields in instance group setting.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \param config The model configuration
/// \return The error status
Status NormalizeInstanceGroup(
const double min_compute_capability,
const std::vector<inference::ModelInstanceGroup>& preferred_groups,
inference::ModelConfig* config);
/// [FIXME] Remove once a more permanent solution is implemented (DLIS-4211)
/// Localize EXECUTION_ENV_PATH in python backend.
/// \param model_path The full-path to the directory containing the model
/// configuration, before localization.
/// \param config The model configuration
/// \param localized_model_dir The localized model directory
/// \return The error status
Status LocalizePythonBackendExecutionEnvironmentPath(
const std::string& model_path, inference::ModelConfig* config,
std::shared_ptr<LocalizedPath>* localized_model_dir);
/// Auto-complete the instance count based on instance kind and backend name.
/// \param group The instance group to set the count for.
/// \param backend The backend name to check against.
/// \return The error status.
Status SetDefaultInstanceCount(
inference::ModelInstanceGroup* group, const std::string& backend);
/// Validate that a model is specified correctly, except for model inputs
/// and outputs. ValidateModelIOConfig() should be called to
/// validate model inputs and outputs.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelConfig(
const inference::ModelConfig& config, const double min_compute_capability);
/// [FIXME] better formalize config normalization / validation
/// Validate instance group setting.
/// \param config The model configuration to validate.
/// \param min_compute_capability The minimum supported CUDA compute
/// capability.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateInstanceGroup(
const inference::ModelConfig& config, const double min_compute_capability);
/// Validate that a model inputs and outputs are specified correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the configuration
/// is not valid.
Status ValidateModelIOConfig(const inference::ModelConfig& config);
/// Validate that input is specified correctly in a model
/// configuration.
/// \param io The model input.
/// \param max_batch_size The max batch size specified in model configuration.
/// \param platform The platform name
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status ValidateModelInput(
const inference::ModelInput& io, int32_t max_batch_size,
const std::string& platform);
/// Validate that an input matches one of the allowed input names.
/// \param io The model input.
/// \param allowed The set of allowed input names.
/// \return The error status. A non-OK status indicates the input
/// is not valid.
Status CheckAllowedModelInput(
const inference::ModelInput& io, const std::set<std::string>& allowed);
/// Validate that an output is specified correctly in a model
/// configuration.
/// \param io The model output.
/// \param max_batch_size The max batch size specified in model configuration.
/// \param platform The platform name
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status ValidateModelOutput(
const inference::ModelOutput& io, int32_t max_batch_size,
const std::string& platform);
/// Validate that an output matches one of the allowed output names.
/// \param io The model output.
/// \param allowed The set of allowed output names.
/// \return The error status. A non-OK status indicates the output
/// is not valid.
Status CheckAllowedModelOutput(
const inference::ModelOutput& io, const std::set<std::string>& allowed);
/// Validate that a model batch inputs and batch outputs are specified
/// correctly.
/// \param config The model configuration to validate.
/// \return The error status. A non-OK status indicates the batch inputs or
/// batch outputs are not valid.
Status ValidateBatchIO(const inference::ModelConfig& config);
/// Parse the 'value' of the parameter 'key' into a boolean value.
/// \param key The name of the parameter.
/// \param value The value of the parameter in string.
/// \param parsed_value Return the boolean of the parameter.
/// \return The error status. A non-OK status indicates failure on parsing the
/// value.
Status ParseBoolParameter(
const std::string& key, std::string value, bool* parsed_value);
/// Parse the 'value' of the parameter 'key' into a long long integer value.
/// \param key The name of the parameter.
/// \param value The value of the parameter in string.
/// \param parsed_value Return the numerical value of the parameter.
/// \return The error status. A non-OK status indicates failure on parsing the
/// value.
Status ParseLongLongParameter(
const std::string& key, const std::string& value, int64_t* parsed_value);
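// Illustrative usage sketch (not part of the original source); the
// parameter names and values are hypothetical:
//
//   bool gather = false;
//   RETURN_IF_ERROR(ParseBoolParameter("gather_kernel", "true", &gather));
//   int64_t size = 0;
//   RETURN_IF_ERROR(ParseLongLongParameter("pool_byte_size", "67108864", &size));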
/// Obtain the 'profile_index' of the 'profile_name'.
/// \param profile_name The name of the profile.
/// \param profile_index Return the index of the profile.
/// \return The error status. A non-OK status indicates failure on getting the
/// value.
Status GetProfileIndex(const std::string& profile_name, int* profile_index);
/// Convert a model configuration protobuf to the equivalent json.
/// \param config The protobuf model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param json Returns the equivalent JSON.
/// \return The error status.
Status ModelConfigToJson(
const inference::ModelConfig& config, const uint32_t config_version,
std::string* json_str);
/// Convert a model configuration JSON to the equivalent protobuf.
/// \param config The JSON model configuration.
/// \param config_version The model configuration will be returned in
/// a format matching this version. If the configuration cannot be
/// represented in the requested version's format then an error will
/// be returned.
/// \param protobuf Returns the equivalent protobuf.
/// \return The error status.
Status JsonToModelConfig(
const std::string& json_config, const uint32_t config_version,
inference::ModelConfig* protobuf_config);
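// Illustrative round-trip sketch (not part of the original source),
// assuming 'config' is a populated inference::ModelConfig:
//
//   std::string json;
//   RETURN_IF_ERROR(ModelConfigToJson(config, 1 /* config_version */, &json));
//   inference::ModelConfig parsed;
//   RETURN_IF_ERROR(JsonToModelConfig(json, 1 /* config_version */, &parsed));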
/// Get the BackendType value for a platform name.
/// \param platform_name The platform name.
/// \return The BackendType, or BACKEND_TYPE_UNKNOWN if the platform
/// name is not recognized.
BackendType GetBackendTypeFromPlatform(const std::string& platform_name);
/// Get the BackendType value for a backend name.
/// \param backend_name The backend name.
/// \return The BackendType, or BACKEND_TYPE_UNKNOWN if the backend
/// name is not recognized.
BackendType GetBackendType(const std::string& backend_name);
/// Get the Triton server data type corresponding to a data type.
/// \param dtype The data type.
/// \return The Triton server data type.
TRITONSERVER_DataType DataTypeToTriton(const inference::DataType dtype);
/// Get the data type corresponding to a Triton server data type.
/// \param dtype The Triton server data type.
/// \return The data type.
inference::DataType TritonToDataType(const TRITONSERVER_DataType dtype);
}} // namespace triton::core
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_lifecycle.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "repo_agent.h"
#include "triton/common/logging.h"
#include "triton/common/thread_pool.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {
const std::string&
ModelReadyStateString(ModelReadyState state)
{
switch (state) {
case ModelReadyState::UNKNOWN: {
static std::string m("UNKNOWN");
return m;
}
case ModelReadyState::READY: {
static std::string m("READY");
return m;
}
case ModelReadyState::UNAVAILABLE: {
static std::string m("UNAVAILABLE");
return m;
}
case ModelReadyState::LOADING: {
static std::string m("LOADING");
return m;
}
case ModelReadyState::UNLOADING: {
static std::string m("UNLOADING");
return m;
}
}
static std::string m("<unknown>");
return m;
}
namespace {
Status
VersionsToLoad(
const std::string model_path, const std::string& name,
const inference::ModelConfig& model_config, std::set<int64_t>* versions)
{
versions->clear();
// Get the integral version number from each version subdirectory
std::set<std::string> subdirs;
RETURN_IF_ERROR(GetDirectorySubdirs(model_path, &subdirs));
std::set<int64_t, std::greater<int64_t>> existing_versions;
for (const auto& subdir : subdirs) {
if (subdir == kWarmupDataFolder || subdir == kInitialStateFolder) {
continue;
}
if ((subdir.length() > 1) && (subdir.front() == '0')) {
LOG_WARNING << "ignore version directory '" << subdir
<< "' which contains leading zeros in its directory name";
continue;
}
try {
int64_t version = std::stoll(subdir);
existing_versions.insert(version);
}
catch (const std::invalid_argument& ia) {
LOG_WARNING << "ignore version directory '" << subdir
<< "' which fails to convert to integral number";
}
}
if (model_config.version_policy().has_specific()) {
for (const auto& v : model_config.version_policy().specific().versions()) {
// Only load the specific versions that are present in the model
// directory. insert() returns true only when the version was not
// already known, i.e. no matching version directory exists.
bool version_not_exist = existing_versions.insert(v).second;
if (!version_not_exist) {
versions->emplace(v);
} else {
LOG_ERROR << "version " << v << " is specified for model '" << name
<< "', but the version directory is not present";
}
}
} else {
if (model_config.version_policy().has_latest()) {
// std::set is sorted with std::greater
for (const auto& v : existing_versions) {
if (versions->size() >=
model_config.version_policy().latest().num_versions()) {
break;
}
versions->emplace(v);
}
} else {
// all
versions->insert(existing_versions.begin(), existing_versions.end());
}
}
return Status::Success;
}
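// Illustrative note (not part of the original source): for a
// hypothetical model directory containing version subdirectories
// "1", "3" and "07", "07" is skipped because of the leading zero.
// With a "latest { num_versions: 1 }" policy only {3} is returned;
// with "specific { versions: 1 }" only {1}; with no policy (all),
// {1, 3}.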
// Use smart pointer with custom deleter so that model state will be updated
// to UNAVAILABLE if all smart pointer copies are out of scope
struct ModelDeleter {
ModelDeleter(std::function<void()> OnDestroyModel)
: OnDestroyModel_(std::move(OnDestroyModel))
{
}
void operator()(Model* model)
{
// The actual model object must be destroyed in a different
// thread. This thread could have a callstack that includes the
// model itself because this deleter could be triggered by a
// request release or response send in the model. Deleting the
// model here would invoke its destructor, which may wait on this
// same thread, so we would deadlock unless a different thread is
// used.
std::function<void()> destroy_fn = OnDestroyModel_;
std::thread dthd([model, destroy_fn]() {
delete model;
destroy_fn();
});
dthd.detach();
}
// Used to inform the ModelLifeCycle that the model handle is destroyed
std::function<void()> OnDestroyModel_;
};
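// Illustrative sketch (not part of the original source) of how the
// deleter above is attached: the ModelInfo's shared_ptr owns the raw
// Model, and the callback runs once the last copy goes out of scope.
// 'raw_model' and the lambda body are hypothetical.
//
//   std::shared_ptr<Model> handle(
//       raw_model, ModelDeleter([]() {
//         // e.g. mark the corresponding ModelInfo as UNAVAILABLE
//       }));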
} // namespace
Status
ModelLifeCycle::Create(
InferenceServer* server, const ModelLifeCycleOptions& options,
std::unique_ptr<ModelLifeCycle>* life_cycle)
{
std::unique_ptr<ModelLifeCycle> local_life_cycle(
new ModelLifeCycle(server, options));
*life_cycle = std::move(local_life_cycle);
return Status::Success;
}
const ModelStateMap
ModelLifeCycle::LiveModelStates(bool strict_readiness)
{
LOG_VERBOSE(2) << "LiveModelStates()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
ModelStateMap live_model_states;
for (auto& model_version : map_) {
bool live = false;
VersionStateMap version_map;
for (auto& version_model : model_version.second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (strict_readiness &&
version_model.second->state_ != ModelReadyState::READY) {
continue;
}
// At least one version is live (ready / loading / unloading)
if ((version_model.second->state_ != ModelReadyState::UNKNOWN) &&
(version_model.second->state_ != ModelReadyState::UNAVAILABLE)) {
live = true;
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
}
if (live) {
live_model_states[model_version.first] = std::move(version_map);
}
}
return live_model_states;
}
Status
ModelLifeCycle::StopAllModels()
{
LOG_VERBOSE(2) << "StopAllModels()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
for (auto& model_version : map_) {
for (auto& version_model : model_version.second) {
if (version_model.second != nullptr) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->model_ != nullptr) {
version_model.second->model_->Stop();
}
}
}
}
return Status::Success;
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelLifeCycle::InflightStatus()
{
LOG_VERBOSE(2) << "InflightStatus()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
std::set<std::tuple<std::string, int64_t, size_t>> inflight_status;
for (auto& model_version : map_) {
for (auto& version_model : model_version.second) {
if (version_model.second != nullptr) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->model_ != nullptr) {
const auto cnt =
version_model.second->model_->InflightInferenceCount();
if (cnt != 0) {
inflight_status.emplace(
model_version.first, version_model.first, cnt);
}
}
}
}
}
return inflight_status;
}
const ModelStateMap
ModelLifeCycle::ModelStates()
{
LOG_VERBOSE(2) << "ModelStates()";
std::lock_guard<std::mutex> map_lock(map_mtx_);
ModelStateMap model_states;
for (auto& model_version : map_) {
VersionStateMap version_map;
for (auto& version_model : model_version.second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
model_states[model_version.first] = std::move(version_map);
}
return model_states;
}
const VersionStateMap
ModelLifeCycle::VersionStates(const std::string& model_name)
{
LOG_VERBOSE(2) << "VersionStates() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
VersionStateMap version_map;
auto mit = map_.find(model_name);
if (mit != map_.end()) {
for (auto& version_model : mit->second) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
version_map[version_model.first] = std::make_pair(
version_model.second->state_, version_model.second->state_reason_);
}
}
return version_map;
}
Status
ModelLifeCycle::ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state)
{
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto mit = map_.find(model_name);
if (mit != map_.end()) {
auto vit = mit->second.find(model_version);
if (vit != mit->second.end()) {
std::lock_guard<std::mutex> lock(vit->second->mtx_);
*state = vit->second->state_;
return Status::Success;
}
}
return Status(
Status::Code::NOT_FOUND, "model '" + model_name + "', version " +
std::to_string(model_version) +
" is not found");
}
Status
ModelLifeCycle::GetModel(
const std::string& model_name, const int64_t version,
std::shared_ptr<Model>* model)
{
LOG_VERBOSE(2) << "GetModel() '" << model_name << "' version " << version;
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto mit = map_.find(model_name);
if (mit == map_.end()) {
return Status(Status::Code::NOT_FOUND, "'" + model_name + "' is not found");
}
auto vit = mit->second.find(version);
if (vit == mit->second.end()) {
if (version != -1) {
return Status(
Status::Code::NOT_FOUND, "'" + model_name + "' version " +
std::to_string(version) +
" is not found");
}
// The case where the request is asking for latest version
int64_t latest = -1;
for (auto& version_model : mit->second) {
if (version_model.first > latest) {
std::lock_guard<std::mutex> lock(version_model.second->mtx_);
if (version_model.second->state_ == ModelReadyState::READY) {
latest = version_model.first;
// Tedious, but the handle must be set for every candidate "latest"
// version encountered to avoid an edge case like the following:
// with "versions : 1 3 2", version 3 is the latest but may be
// requested to be unloaded while the iterator is examining version
// 2; holding 'model' ensures version 3 remains valid.
*model = version_model.second->model_;
}
}
}
if (latest == -1) {
return Status(
Status::Code::NOT_FOUND,
"'" + model_name + "' has no available versions");
}
} else {
std::lock_guard<std::mutex> lock(vit->second->mtx_);
if (vit->second->state_ == ModelReadyState::READY) {
*model = vit->second->model_;
} else {
return Status(
Status::Code::UNAVAILABLE, "'" + model_name + "' version " +
std::to_string(version) +
" is not at ready state");
}
}
return Status::Success;
}
Status
ModelLifeCycle::AsyncUnload(const std::string& model_name)
{
LOG_VERBOSE(2) << "AsyncUnload() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
if (it == map_.end()) {
return Status(
Status::Code::INVALID_ARG, "Model to be unloaded has not been served");
}
// Get the existing agent models and notify the unload action
const uint64_t now_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
for (auto& version : it->second) {
auto& model_info = version.second;
std::lock_guard<std::mutex> lock(model_info->mtx_);
model_info->last_update_ns_ = now_ns;
// Unload the serving model. For a model that is in LOADING state,
// the updated timestamp signals that there is a newer update on the
// model info and the in-progress load should be aborted.
if (model_info->state_ == ModelReadyState::READY) {
if (model_info->agent_model_list_ != nullptr) {
// Only log the error because the model should be unloaded regardless
auto status = model_info->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_UNLOAD);
if (!status.IsOk()) {
LOG_ERROR
<< "Agent model returns error on TRITONREPOAGENT_ACTION_UNLOAD: "
<< status.AsString();
}
}
// unload
model_info->Release();
}
}
return Status::Success;
}
Status
ModelLifeCycle::AsyncLoad(
const std::string& model_name, const std::string& model_path,
const inference::ModelConfig& model_config, const bool is_config_provided,
const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
std::function<void(Status)>&& OnComplete)
{
LOG_VERBOSE(2) << "AsyncLoad() '" << model_name << "'";
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
if (it == map_.end()) {
it = map_.emplace(std::make_pair(model_name, VersionMap())).first;
}
std::set<int64_t> versions;
RETURN_IF_ERROR(
VersionsToLoad(model_path, model_name, model_config, &versions));
if (versions.empty()) {
return Status(
Status::Code::INVALID_ARG,
"at least one version must be available under the version policy of "
"model '" +
model_name + "'");
}
const uint64_t now_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
std::shared_ptr<LoadTracker> load_tracker(
new LoadTracker(versions.size(), now_ns));
for (const auto& version : versions) {
std::unique_ptr<ModelInfo> linfo(
new ModelInfo(model_path, model_config, now_ns));
ModelInfo* model_info = linfo.get();
LOG_INFO << "loading: " << model_name << ":" << version;
model_info->state_ = ModelReadyState::LOADING;
model_info->state_reason_.clear();
model_info->agent_model_list_ = agent_model_list;
auto res = it->second.emplace(
std::make_pair(version, std::unique_ptr<ModelInfo>()));
if (res.second) {
res.first->second = std::move(linfo);
} else {
// There is already a record of this model version. If the version is
// currently being served, the reload should be performed in the
// background to avoid version downtime. Otherwise, swap in the new
// model info and monitor the state of the newly loading model.
auto& serving_model = res.first->second;
std::lock_guard<std::mutex> lock(serving_model->mtx_);
if (serving_model->state_ == ModelReadyState::READY) {
background_models_[(uintptr_t)model_info] = std::move(linfo);
} else {
// swap the monitoring model info
serving_model.swap(linfo);
// Further check the state: if the model is LOADING / UNLOADING, keep
// the old info in 'background_models_' so the object stays valid,
// because it will be accessed by a different thread once the
// operation completes.
if ((linfo->state_ == ModelReadyState::LOADING) ||
(linfo->state_ == ModelReadyState::UNLOADING)) {
ModelInfo* key = linfo.get();
background_models_[(uintptr_t)key] = std::move(linfo);
}
}
}
// Load model asynchronously via thread pool
load_pool_->Enqueue([this, model_name, version, model_info, OnComplete,
load_tracker, is_config_provided]() {
CreateModel(model_name, version, model_info, is_config_provided);
OnLoadComplete(model_name, version, model_info, OnComplete, load_tracker);
});
}
return Status::Success;
}
void
ModelLifeCycle::CreateModel(
const std::string& model_name, const int64_t version, ModelInfo* model_info,
const bool is_config_provided)
{
LOG_VERBOSE(2) << "CreateModel() '" << model_name << "' version " << version;
const auto& model_config = model_info->model_config_;
// Create model
Status status;
std::unique_ptr<Model> is;
// If 'backend' is specified in the config then use the new triton
// backend.
if (!model_config.backend().empty()) {
std::unique_ptr<TritonModel> model;
status = TritonModel::Create(
server_, model_info->model_path_, cmdline_config_map_, host_policy_map_,
model_name, version, model_config, is_config_provided, &model);
is.reset(model.release());
} else {
#ifdef TRITON_ENABLE_ENSEMBLE
if (model_info->is_ensemble_) {
status = EnsembleModel::Create(
server_, model_info->model_path_, version, model_config,
is_config_provided, min_compute_capability_, &is);
// Complete the label provider with label information from the
// involved models. This must be done here because the involved
// models may not be obtainable from the server, since this may
// happen during server initialization.
if (status.IsOk()) {
std::set<std::string> no_label_outputs;
const auto& label_provider = is->GetLabelProvider();
for (const auto& output : model_config.output()) {
if (label_provider->GetLabel(output.name(), 0).empty()) {
no_label_outputs.emplace(output.name());
}
}
for (const auto& element : model_config.ensemble_scheduling().step()) {
for (const auto& pair : element.output_map()) {
// Found a model that produces one of the missing outputs
if (no_label_outputs.find(pair.second) != no_label_outputs.end()) {
std::shared_ptr<Model> model;
// Safe to obtain model because the ensemble can't be loaded
// until the involved models are ready
GetModel(element.model_name(), element.model_version(), &model);
label_provider->AddLabels(
pair.second,
model->GetLabelProvider()->GetLabels(pair.first));
}
}
}
}
} else
#endif // TRITON_ENABLE_ENSEMBLE
{
status = Status(
Status::Code::INVALID_ARG,
"unknown platform '" + model_config.platform() + "'");
}
}
std::lock_guard<std::mutex> lock(model_info->mtx_);
if (status.IsOk()) {
// [FIXME] better way to manage agent model lifecycle
// Let the deleter also hold a shared pointer copy of the agent model
// list, because the reference in ModelInfo can be cleared before the
// Model object is destroyed, and we want the agent model to remain
// valid for receiving the UNLOAD_COMPLETE signal (see
// ~TritonRepoAgentModelList for detail).
auto agent_model_list = model_info->agent_model_list_;
model_info->model_.reset(
is.release(), ModelDeleter([this, model_name, version, model_info,
agent_model_list]() mutable {
LOG_VERBOSE(2) << "OnDestroy callback() '" << model_name
<< "' version " << version;
LOG_INFO << "successfully unloaded '" << model_name << "' version "
<< version;
// Update model state as it is fully unloaded
{
std::lock_guard<std::mutex> lock(model_info->mtx_);
model_info->state_ = ModelReadyState::UNAVAILABLE;
model_info->state_reason_ = "unloaded";
}
// Check if the model info is in the background; if so, remove it
// from the map.
std::lock_guard<std::mutex> lk(this->map_mtx_);
auto it = this->background_models_.find((uintptr_t)model_info);
if (it != this->background_models_.end()) {
this->background_models_.erase(it);
}
}));
} else {
LOG_ERROR << "failed to load '" << model_name << "' version " << version
<< ": " << status.AsString();
model_info->state_ = ModelReadyState::UNAVAILABLE;
model_info->state_reason_ = status.AsString();
}
}
void
ModelLifeCycle::OnLoadComplete(
const std::string& model_name, const int64_t version, ModelInfo* model_info,
std::function<void(Status)> OnComplete,
std::shared_ptr<LoadTracker> load_tracker)
{
std::lock_guard<std::mutex> tracker_lock(load_tracker->mtx_);
++load_tracker->completed_version_cnt_;
load_tracker->load_set_[version] = model_info;
// A version will not be marked ready until all versions are ready.
// This simplifies unloading when one version fails to load, as the
// other versions won't have in-flight requests.
if (model_info->state_ != ModelReadyState::LOADING) {
load_tracker->load_failed_ = true;
load_tracker->reason_ +=
("version " + std::to_string(version) + " is at " +
ModelReadyStateString(model_info->state_) +
" state: " + model_info->state_reason_ + ";");
}
// Check if all versions are completed and finish the load
if (load_tracker->completed_version_cnt_ ==
load_tracker->affected_version_cnt_) {
// hold 'map_mtx_' as there will be change onto the model info map
std::lock_guard<std::mutex> map_lock(map_mtx_);
auto it = map_.find(model_name);
// Check if the load is the latest foreground action on the model
for (const auto& version_info : it->second) {
if (version_info.second->last_update_ns_ >
load_tracker->last_update_ns_) {
load_tracker->load_failed_ = true;
load_tracker->reason_ =
"Newer operation has been applied to the model lifecycle, current "
"load operation is out-dated.";
break;
}
}
if (load_tracker->load_failed_) {
// Move agent list out of ModelInfo as it needs to be invoked
// after all ModelInfos are reset
std::shared_ptr<TritonRepoAgentModelList> lagent_list;
if (model_info->agent_model_list_) {
lagent_list = std::move(model_info->agent_model_list_);
}
// If any of the versions fails to load, abort the load and unload
// all newly loaded versions
for (auto& loaded : load_tracker->load_set_) {
// Unload directly; the object is being managed either in the
// foreground or the background.
std::lock_guard<std::mutex> lock(loaded.second->mtx_);
if (loaded.second->model_ != nullptr) {
loaded.second->Release();
}
}
if (lagent_list) {
auto status =
lagent_list->InvokeAgentModels(TRITONREPOAGENT_ACTION_LOAD_FAIL);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_FAIL: "
<< status.AsString();
}
}
} else {
// Unload any previously loaded versions that are still available
for (auto& version_info : it->second) {
auto& mi = version_info.second;
std::lock_guard<std::mutex> info_lk(mi->mtx_);
if ((mi->state_ == ModelReadyState::READY) &&
(mi->last_update_ns_ < load_tracker->last_update_ns_)) {
if (mi->agent_model_list_ != nullptr) {
auto status = mi->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_UNLOAD);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_UNLOAD: "
<< status.AsString();
}
}
mi->Release();
}
}
// Mark current versions ready and track info in foreground
for (auto& loaded : load_tracker->load_set_) {
std::lock_guard<std::mutex> curr_info_lk(loaded.second->mtx_);
loaded.second->state_ = ModelReadyState::READY;
loaded.second->state_reason_.clear();
LOG_INFO << "successfully loaded '" << model_name << "' version "
<< loaded.first;
auto bit = background_models_.find((uintptr_t)loaded.second);
// Check whether this version was loaded in the background; if so,
// replace and unload the currently serving version.
if (bit != background_models_.end()) {
auto vit = it->second.find(loaded.first);
// Need to lock the previous model info in case the model is
// loading / unloading; this ensures the model state is consistent
// even when the load / unload completes.
std::lock_guard<std::mutex> prev_info_lk(vit->second->mtx_);
// swap previous info into local unique pointer
auto linfo = std::move(bit->second);
vit->second.swap(linfo);
background_models_.erase(bit);
// if previous info is under change, put into 'background_models_'
if ((linfo->state_ == ModelReadyState::LOADING) ||
(linfo->state_ == ModelReadyState::UNLOADING)) {
ModelInfo* key = linfo.get();
background_models_[(uintptr_t)key] = std::move(linfo);
}
}
}
if (model_info->agent_model_list_) {
auto status = model_info->agent_model_list_->InvokeAgentModels(
TRITONREPOAGENT_ACTION_LOAD_COMPLETE);
if (!status.IsOk()) {
LOG_ERROR << "Agent model returns error on "
"TRITONREPOAGENT_ACTION_LOAD_COMPLETE: "
<< status.AsString();
}
}
}
if (OnComplete != nullptr) {
OnComplete(
load_tracker->load_failed_
? Status(Status::Code::INVALID_ARG, load_tracker->reason_)
: Status::Success);
}
}
}
}} // namespace triton::core
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#pragma once
#include <functional>
#include <map>
#include <mutex>
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "repo_agent.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "triton/common/thread_pool.h"
namespace triton { namespace core {
struct ModelLifeCycleOptions {
explicit ModelLifeCycleOptions(
const double min_compute_capability,
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map,
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map,
const unsigned int model_load_thread_count)
: min_compute_capability_(min_compute_capability),
backend_cmdline_config_map_(backend_cmdline_config_map),
host_policy_map_(host_policy_map),
model_load_thread_count_(model_load_thread_count)
{
}
// The minimum supported CUDA compute capability.
const double min_compute_capability_;
// The backend configuration settings specified on the command-line
const triton::common::BackendCmdlineConfigMap& backend_cmdline_config_map_;
// The host policy setting used when loading models.
const triton::common::HostPolicyCmdlineConfigMap& host_policy_map_;
// Number of threads to use for concurrently loading models
const unsigned int model_load_thread_count_;
};
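// Illustrative construction sketch (not part of the original source);
// the maps and thread count below are hypothetical, and the referenced
// maps must outlive the options object since they are held by
// reference:
//
//   triton::common::BackendCmdlineConfigMap backend_config;
//   triton::common::HostPolicyCmdlineConfigMap host_policy;
//   ModelLifeCycleOptions options(
//       6.0 /* min_compute_capability */, backend_config, host_policy,
//       4 /* model_load_thread_count */);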
/// Readiness status for models.
enum class ModelReadyState {
// The model is in an unknown state. The model is not available for
// inferencing.
UNKNOWN,
// The model is ready and available for inferencing.
READY,
// The model is unavailable, indicating that the model failed to
// load or has been implicitly or explicitly unloaded. The model is
// not available for inferencing.
UNAVAILABLE,
// The model is being loaded by the inference server. The model is
// not available for inferencing.
LOADING,
// The model is being unloaded by the inference server. The model is
// not available for inferencing.
UNLOADING
};
/// Get the string representation for a ModelReadyState
const std::string& ModelReadyStateString(ModelReadyState state);
using VersionStateMap =
std::map<int64_t, std::pair<ModelReadyState, std::string>>;
using ModelStateMap = std::map<std::string, VersionStateMap>;
// Helper class to manage the lifecycle of a list of associated agent models
class TritonRepoAgentModelList {
public:
TritonRepoAgentModelList()
: last_action_type_(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE){};
~TritonRepoAgentModelList()
{
// Use the destructor to finish the unload lifecycle without
// explicitly managing the last step in ModelLifeCycle.
if (last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD) {
InvokeAgentModels(TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE);
}
}
Status AddAgentModel(std::unique_ptr<TritonRepoAgentModel>&& agent_model)
{
agent_models_.emplace_back(std::move(agent_model));
return Status::Success;
}
size_t Size() { return agent_models_.size(); }
TritonRepoAgentModel* Back() { return agent_models_.back().get(); }
Status InvokeAgentModels(const TRITONREPOAGENT_ActionType action_type)
{
// Special handling for the current model lifecycle implementation,
// the repo agent may be asked to perform UNLOAD action multiple times,
// and the requests after the first should be ignored.
const bool repeated_unload =
(action_type == TRITONREPOAGENT_ACTION_UNLOAD) &&
(last_action_type_ == TRITONREPOAGENT_ACTION_UNLOAD);
if (repeated_unload) {
return Status::Success;
}
last_action_type_ = action_type;
switch (action_type) {
case TRITONREPOAGENT_ACTION_LOAD:
case TRITONREPOAGENT_ACTION_UNLOAD: {
for (size_t idx = 0; idx < agent_models_.size(); ++idx) {
RETURN_IF_ERROR(agent_models_[idx]->InvokeAgent(action_type));
}
break;
}
case TRITONREPOAGENT_ACTION_LOAD_COMPLETE:
case TRITONREPOAGENT_ACTION_LOAD_FAIL:
case TRITONREPOAGENT_ACTION_UNLOAD_COMPLETE: {
// reverse order
for (size_t one_pass_idx = agent_models_.size(); one_pass_idx > 0;
--one_pass_idx) {
RETURN_IF_ERROR(
agent_models_[one_pass_idx - 1]->InvokeAgent(action_type));
}
break;
}
}
return Status::Success;
}
private:
DISALLOW_COPY_AND_ASSIGN(TritonRepoAgentModelList);
std::vector<std::unique_ptr<TritonRepoAgentModel>> agent_models_;
TRITONREPOAGENT_ActionType last_action_type_;
};
class InferenceServer;
class Model;
class ModelLifeCycle {
public:
static Status Create(
InferenceServer* server, const ModelLifeCycleOptions& options,
std::unique_ptr<ModelLifeCycle>* life_cycle);
~ModelLifeCycle()
{
// Explicitly clean up thread pool first to clean up any pending callbacks
// that may modify model lifecycle members
load_pool_.reset();
map_.clear();
}
// Start loading the model with the specified versions asynchronously.
// Any versions currently being served will be unloaded only after
// the load finishes successfully.
Status AsyncLoad(
const std::string& model_name, const std::string& model_path,
const inference::ModelConfig& model_config, const bool is_config_provided,
const std::shared_ptr<TritonRepoAgentModelList>& agent_model_list,
std::function<void(Status)>&& OnComplete);
// Unload model asynchronously.
Status AsyncUnload(const std::string& model_name);
// Get the specified version of the model. The latest ready version
// is retrieved if 'version' is -1. Return an error if the specified
// version is not found or is not ready.
Status GetModel(
const std::string& model_name, const int64_t version,
std::shared_ptr<Model>* model);
// Get the ModelStateMap representation of the live models. A model is
// live if at least one of its versions is neither unknown nor
// unavailable. If 'strict_readiness' is true, a model is only live if
// at least one of its versions is ready.
const ModelStateMap LiveModelStates(bool strict_readiness = false);
// Get the ModelStateMap representation of the models.
const ModelStateMap ModelStates();
// Get the VersionStateMap representation of the specified model.
const VersionStateMap VersionStates(const std::string& model_name);
// Get the state of a specific model version.
Status ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state);
// Instruct the model to stop accepting new inference requests.
Status StopAllModels();
// Return the number of in-flight inferences, if any; model versions
// that have no in-flight inferences are not included.
const std::set<std::tuple<std::string, int64_t, size_t>> InflightStatus();
private:
struct ModelInfo {
ModelInfo(
const std::string& model_path,
const inference::ModelConfig& model_config,
const uint64_t last_update_ns)
: model_config_(model_config), model_path_(model_path),
#ifdef TRITON_ENABLE_ENSEMBLE
is_ensemble_(model_config.platform() == kEnsemblePlatform),
#else
is_ensemble_(false),
#endif // TRITON_ENABLE_ENSEMBLE
last_update_ns_(last_update_ns), state_(ModelReadyState::UNKNOWN)
{
}
// Release the flyweights held by the ModelInfo object and reflect
// this as 'UNLOADING' in the model state. Note that 'mtx_' should be
// acquired before invoking this function to prevent a possible data
// race.
void Release()
{
state_ = ModelReadyState::UNLOADING;
state_reason_.clear();
agent_model_list_.reset();
model_.reset();
}
const inference::ModelConfig model_config_;
const std::string model_path_;
const bool is_ensemble_;
std::mutex mtx_;
uint64_t last_update_ns_;
ModelReadyState state_;
std::string state_reason_;
// flyweight
std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
std::shared_ptr<Model> model_;
};
struct LoadTracker {
LoadTracker(
const size_t affected_version_cnt, const uint64_t last_update_ns)
: last_update_ns_(last_update_ns),
affected_version_cnt_(affected_version_cnt), load_failed_(false),
completed_version_cnt_(0)
{
}
const uint64_t last_update_ns_;
const size_t affected_version_cnt_;
std::mutex mtx_;
bool load_failed_;
std::string reason_;
size_t completed_version_cnt_;
std::map<int64_t, ModelInfo*> load_set_;
};
ModelLifeCycle(InferenceServer* server, const ModelLifeCycleOptions& options)
: server_(server),
min_compute_capability_(options.min_compute_capability_),
cmdline_config_map_(options.backend_cmdline_config_map_),
host_policy_map_(options.host_policy_map_)
{
load_pool_.reset(new triton::common::ThreadPool(
std::max(1u, options.model_load_thread_count_)));
}
void CreateModel(
const std::string& model_name, const int64_t version,
ModelInfo* model_info, const bool is_config_provided);
// Callback function template for model load.
// 'OnComplete' needs to be passed by value for now as there can be
// multiple versions to be loaded and each holds a copy of
// the 'OnComplete' callback.
void OnLoadComplete(
const std::string& model_name, const int64_t version,
ModelInfo* model_info, std::function<void(Status)> OnComplete,
std::shared_ptr<LoadTracker> load_tracker);
// Mutex for 'map_' and 'background_models_'
std::mutex map_mtx_;
using VersionMap = std::map<int64_t, std::unique_ptr<ModelInfo>>;
using ModelMap = std::map<std::string, VersionMap>;
ModelMap map_;
// Models that are being loaded / unloaded in background
std::map<uintptr_t, std::unique_ptr<ModelInfo>> background_models_;
InferenceServer* server_;
const double min_compute_capability_;
const triton::common::BackendCmdlineConfigMap cmdline_config_map_;
const triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
// Fixed-size thread pool to load models at specified concurrency
std::unique_ptr<triton::common::ThreadPool> load_pool_;
};
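// Illustrative usage sketch (not part of the original source),
// assuming 'server', 'options', 'path' and 'config' are already set
// up by the caller:
//
//   std::unique_ptr<ModelLifeCycle> life_cycle;
//   RETURN_IF_ERROR(ModelLifeCycle::Create(server, options, &life_cycle));
//   RETURN_IF_ERROR(life_cycle->AsyncLoad(
//       "my_model", path, config, false /* is_config_provided */,
//       nullptr /* agent_model_list */, [](Status load_status) {
//         // called once all requested versions finish loading
//       }));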
}} // namespace triton::core
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
//
#include "model_repository_manager.h"
#include <algorithm>
#include <deque>
#include <future>
#include <stdexcept>
#include <thread>
#include "constants.h"
#include "ensemble_utils.h"
#include "filesystem.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
#include "backend_model.h"
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_model.h"
#endif // TRITON_ENABLE_ENSEMBLE
namespace triton { namespace core {
namespace {
static std::string file_prefix = "file:";
// Internal repo agent used for model file override
class LocalizeRepoAgent : public TritonRepoAgent {
public:
LocalizeRepoAgent()
: TritonRepoAgent("ModelRepositoryManager::LocalizeRepoAgent")
{
// Callbacks below interact with TritonRepoAgentModel directly knowing that
// it is the internal implementation of TRITONREPOAGENT_AgentModel
model_action_fn_ = [](TRITONREPOAGENT_Agent* agent,
TRITONREPOAGENT_AgentModel* model,
const TRITONREPOAGENT_ActionType action_type)
-> TRITONSERVER_Error* {
auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
switch (action_type) {
case TRITONREPOAGENT_ACTION_LOAD: {
// localize the override files for model loading,
// as currently the model is expected to load from local directory
const char* temp_dir_cstr = nullptr;
RETURN_TRITONSERVER_ERROR_IF_ERROR(
agent_model->AcquireMutableLocation(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, &temp_dir_cstr));
const std::string temp_dir = temp_dir_cstr;
const auto& files =
*reinterpret_cast<std::vector<const InferenceParameter*>*>(
agent_model->State());
bool found_config = false;
for (const auto& file : files) {
if (file->Name() == "config") {
if (file->Type() != TRITONSERVER_PARAMETER_STRING) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"Config parameter 'config' must have string type for its "
"value");
}
inference::ModelConfig config;
RETURN_TRITONSERVER_ERROR_IF_ERROR(JsonToModelConfig(
file->ValueString(), 1 /* config_version */, &config));
RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteTextProto(
JoinPath({temp_dir, kModelConfigPbTxt}), config));
found_config = true;
} else if (file->Name().rfind(file_prefix, 0) == 0) {
if (file->Type() != TRITONSERVER_PARAMETER_BYTES) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("File parameter '") + file->Name() +
"' must have bytes type for its value")
.c_str());
}
// Save model file to the instructed directory
// mkdir
const std::string file_path =
JoinPath({temp_dir, file->Name().substr(file_prefix.size())});
const std::string dir = DirName(file_path);
bool dir_exist = false;
RETURN_TRITONSERVER_ERROR_IF_ERROR(FileExists(dir, &dir_exist));
if (dir_exist) {
bool is_dir = false;
RETURN_TRITONSERVER_ERROR_IF_ERROR(IsDirectory(dir, &is_dir));
if (!is_dir) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
(std::string("Invalid file parameter '") + file->Name() +
"', directory has been created as a file")
.c_str());
}
} else {
RETURN_TRITONSERVER_ERROR_IF_ERROR(
MakeDirectory(dir, true /* recursive */));
}
// write
RETURN_TRITONSERVER_ERROR_IF_ERROR(WriteBinaryFile(
file_path,
reinterpret_cast<const char*>(file->ValuePointer()),
file->ValueByteSize()));
}
}
if (!found_config) {
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"Load parameter 'config' must be specified for model file "
"override");
}
// Commit the temporary directory
RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->SetLocation(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, temp_dir_cstr));
break;
}
default:
break;
}
return nullptr; // success
};
model_fini_fn_ =
[](TRITONREPOAGENT_Agent* agent,
TRITONREPOAGENT_AgentModel* model) -> TRITONSERVER_Error* {
auto agent_model = reinterpret_cast<TritonRepoAgentModel*>(model);
RETURN_TRITONSERVER_ERROR_IF_ERROR(agent_model->DeleteMutableLocation());
return nullptr; // success
};
}
};
Status
CreateAgentModelListWithLoadAction(
const inference::ModelConfig& original_model_config,
const std::string& original_model_path,
std::shared_ptr<TritonRepoAgentModelList>* agent_model_list)
{
if (original_model_config.has_model_repository_agents()) {
// Trick to append user specified repo agent on top of internal ones
std::shared_ptr<TritonRepoAgentModelList> lagent_model_list;
if (*agent_model_list != nullptr) {
lagent_model_list = std::move(*agent_model_list);
} else {
lagent_model_list.reset(new TritonRepoAgentModelList());
}
FileSystemType filesystem_type;
RETURN_IF_ERROR(GetFileSystemType(original_model_path, &filesystem_type));
TRITONREPOAGENT_ArtifactType artifact_type =
TRITONREPOAGENT_ARTIFACT_FILESYSTEM;
if (filesystem_type != FileSystemType::LOCAL) {
artifact_type = TRITONREPOAGENT_ARTIFACT_REMOTE_FILESYSTEM;
}
const char* location = original_model_path.c_str();
inference::ModelConfig model_config = original_model_config;
for (const auto& agent_config :
original_model_config.model_repository_agents().agents()) {
std::shared_ptr<TritonRepoAgent> agent;
RETURN_IF_ERROR(
TritonRepoAgentManager::CreateAgent(agent_config.name(), &agent));
TritonRepoAgent::Parameters agent_params;
for (const auto& parameter : agent_config.parameters()) {
agent_params.emplace_back(parameter.first, parameter.second);
}
std::unique_ptr<TritonRepoAgentModel> agent_model;
if (lagent_model_list->Size() != 0) {
lagent_model_list->Back()->Location(&artifact_type, &location);
const auto config_path = JoinPath({location, kModelConfigPbTxt});
if (!ReadTextProto(config_path, &model_config).IsOk()) {
model_config.Clear();
}
}
RETURN_IF_ERROR(TritonRepoAgentModel::Create(
artifact_type, location, model_config, agent, agent_params,
&agent_model));
RETURN_IF_ERROR(agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
lagent_model_list->AddAgentModel(std::move(agent_model));
}
*agent_model_list = std::move(lagent_model_list);
}
return Status::Success;
}
int64_t
GetModifiedTime(const std::string& path)
{
// If there is an error in any step the fall-back default
// modification time is 0. This means that in error cases 'path'
// will show as not modified. This is the safe fall-back to avoid
// assuming a model is constantly being modified.
bool path_is_dir;
Status status = IsDirectory(path, &path_is_dir);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
// If 'path' is a file, return its mtime. Otherwise, use the
// modification time of the directory as a baseline in case of file
// deletion.
int64_t mtime = 0;
status = FileModificationTime(path, &mtime);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
if (!path_is_dir) {
return mtime;
}
// 'path' is a directory. Return the most recent mtime of the
// contents of the directory.
std::set<std::string> contents;
status = GetDirectoryContents(path, &contents);
if (!status.IsOk()) {
LOG_ERROR << "Failed to determine modification time for '" << path
<< "': " << status.AsString();
return 0;
}
for (const auto& child : contents) {
const auto full_path = JoinPath({path, child});
mtime = std::max(mtime, GetModifiedTime(full_path));
}
return mtime;
}
// Return true if any file in the subdirectory rooted at 'path' has
// been modified more recently than 'last_ns'. Return the most recent
// modification time in 'last_ns'.
bool
IsModified(const std::string& path, int64_t* last_ns)
{
const int64_t repo_ns = GetModifiedTime(path);
bool modified = repo_ns > *last_ns;
*last_ns = repo_ns;
return modified;
}
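// Illustrative sketch (not part of the original source) of how the
// helper above is typically used when polling a repository; the path
// and variable names are hypothetical:
//
//   int64_t last_ns = 0;
//   if (IsModified("/models/my_model", &last_ns)) {
//     // something under the model directory changed since the last
//     // poll; 'last_ns' now holds the most recent modification time
//   }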
} // namespace
struct ModelRepositoryManager::ModelInfo {
ModelInfo(
const int64_t mtime_nsec, const int64_t prev_mtime_ns,
const std::string& model_path)
: mtime_nsec_(mtime_nsec), prev_mtime_ns_(prev_mtime_ns),
explicitly_load_(true), model_path_(model_path),
is_config_provided_(false)
{
}
ModelInfo()
: mtime_nsec_(0), prev_mtime_ns_(0), explicitly_load_(true),
is_config_provided_(false)
{
}
int64_t mtime_nsec_;
int64_t prev_mtime_ns_;
bool explicitly_load_;
inference::ModelConfig model_config_;
std::string model_path_;
// Temporary location to hold the agent model list before creating the
// model. Ownership must transfer to ModelLifeCycle to ensure the
// agent model lifecycle is handled properly.
std::shared_ptr<TritonRepoAgentModelList> agent_model_list_;
bool is_config_provided_;
};
ModelRepositoryManager::ModelRepositoryManager(
const std::set<std::string>& repository_paths, const bool autofill,
const bool polling_enabled, const bool model_control_enabled,
const double min_compute_capability,
std::unique_ptr<ModelLifeCycle> life_cycle)
: repository_paths_(repository_paths), autofill_(autofill),
polling_enabled_(polling_enabled),
model_control_enabled_(model_control_enabled),
min_compute_capability_(min_compute_capability),
model_life_cycle_(std::move(life_cycle))
{
}
ModelRepositoryManager::~ModelRepositoryManager() {}
Status
ModelRepositoryManager::Create(
InferenceServer* server, const std::string& server_version,
const std::set<std::string>& repository_paths,
const std::set<std::string>& startup_models, const bool strict_model_config,
const bool polling_enabled, const bool model_control_enabled,
const ModelLifeCycleOptions& life_cycle_options,
std::unique_ptr<ModelRepositoryManager>* model_repository_manager)
{
// The rest only matters if the repository path is a valid directory
for (const auto& path : repository_paths) {
bool path_is_dir;
RETURN_IF_ERROR(IsDirectory(path, &path_is_dir));
if (!path_is_dir) {
return Status(
Status::Code::INVALID_ARG,
"repository path is not a valid directory");
}
}
if (polling_enabled && model_control_enabled) {
return Status(
Status::Code::INVALID_ARG,
"cannot enable both polling and explicit model control");
}
std::unique_ptr<ModelLifeCycle> life_cycle;
RETURN_IF_ERROR(
ModelLifeCycle::Create(server, life_cycle_options, &life_cycle));
// Not setting the smart pointer directly to simplify clean up
std::unique_ptr<ModelRepositoryManager> local_manager(
new ModelRepositoryManager(
repository_paths, !strict_model_config, polling_enabled,
model_control_enabled, life_cycle_options.min_compute_capability_,
std::move(life_cycle)));
*model_repository_manager = std::move(local_manager);
// Support loading all models on startup in explicit model control mode with
// special startup_model name "*". This does not imply support for pattern
// matching in model names.
bool load_all_models_on_startup = false;
if ((startup_models.find("*") != startup_models.end()) &&
model_control_enabled) {
if (startup_models.size() > 1) {
return Status(
Status::Code::INVALID_ARG,
"Wildcard model name '*' must be the ONLY startup model "
"if specified at all.");
}
load_all_models_on_startup = true;
}
bool all_models_polled = true;
if (!model_control_enabled || load_all_models_on_startup) {
// Only errors that happen before model load / unload will be returned;
// model loading / unloading errors will be printed but ignored.
RETURN_IF_ERROR(
(*model_repository_manager)->PollAndUpdateInternal(&all_models_polled));
} else {
// Load each specified startup_model
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
models;
for (const auto& model_name : startup_models) {
models[model_name];
}
RETURN_IF_ERROR(
(*model_repository_manager)
->LoadUnloadModels(
models, ActionType::LOAD, false, &all_models_polled));
}
if (!all_models_polled) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
// Some models may fail to load after the model manager is created; return a
// proper error and let the caller decide whether to proceed.
for (const auto& model : (*model_repository_manager)->infos_) {
const auto version_states =
(*model_repository_manager)
->model_life_cycle_->VersionStates(model.first);
// Return a general error message; the detail of each model's loading state
// is logged separately.
if (version_states.empty()) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
for (const auto& state : version_states) {
if (state.second.first != ModelReadyState::READY) {
return Status(Status::Code::INTERNAL, "failed to load all models");
}
}
}
return Status::Success;
}
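// Deployment sketch (hedged): the load-all wildcard handled above is normally
// reached through server startup flags along the lines of
//
//   tritonserver --model-repository=/models \
//                --model-control-mode=explicit --load-model=*
//
// where "*" must be the only --load-model value. Flag spellings follow the
// commonly documented CLI and may differ across Triton versions.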
Status
ModelRepositoryManager::PollAndUpdate()
{
if (!polling_enabled_) {
return Status(Status::Code::UNAVAILABLE, "polling is disabled");
}
bool all_models_polled;
return PollAndUpdateInternal(&all_models_polled);
}
Status
ModelRepositoryManager::PollAndUpdateInternal(bool* all_models_polled)
{
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
std::set<std::string> added, deleted, modified, unmodified;
// We don't modify 'infos_' in place to minimize how long we need to
// hold the lock and also to prevent any partial changes due to an error
// during processing.
ModelInfoMap new_infos;
// Each subdirectory of the repository path is a model directory from
// which we read the model configuration.
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
subdirs;
RETURN_IF_ERROR(Poll(
subdirs, &added, &deleted, &modified, &unmodified, &new_infos,
all_models_polled));
// Anything in 'infos_' that is not in "added", "modified", or
// "unmodified" is deleted.
for (const auto& pr : infos_) {
if ((added.find(pr.first) == added.end()) &&
(modified.find(pr.first) == modified.end()) &&
(unmodified.find(pr.first) == unmodified.end())) {
deleted.insert(pr.first);
}
}
// Nothing to do if no model adds, deletes or modifies.
if (added.empty() && deleted.empty() && modified.empty()) {
return Status::Success;
}
infos_.swap(new_infos);
UpdateDependencyGraph(added, deleted, modified);
for (const auto& name : deleted) {
model_life_cycle_->AsyncUnload(name);
}
// model loading / unloading error will be printed but ignored
LoadModelByDependency();
return Status::Success;
}
std::map<std::string, Status>
ModelRepositoryManager::LoadModelByDependency()
{
std::map<std::string, Status> res;
struct ModelState {
ModelState(DependencyNode* node) : node_(node), status_(Status::Success) {}
DependencyNode* node_;
Status status_;
std::promise<void> ready_;
};
NodeSet loaded_models;
auto set_pair = ModelsToLoadUnload(loaded_models);
// Loop until all models are loaded / unloaded
while ((!set_pair.first.empty()) || (!set_pair.second.empty())) {
loaded_models.clear();
// Unload invalid models first
for (auto& invalid_model : set_pair.second) {
model_life_cycle_->AsyncUnload(invalid_model->model_name_);
LOG_ERROR << invalid_model->status_.AsString();
invalid_model->loaded_versions_ = std::set<int64_t>();
loaded_models.emplace(invalid_model);
}
// load valid models and wait for load results
std::vector<std::unique_ptr<ModelState>> model_states;
for (auto& valid_model : set_pair.first) {
model_states.emplace_back(new ModelState(valid_model));
auto model_state = model_states.back().get();
const auto itr = infos_.find(valid_model->model_name_);
auto status = model_life_cycle_->AsyncLoad(
valid_model->model_name_, itr->second->model_path_,
valid_model->model_config_, itr->second->is_config_provided_,
itr->second->agent_model_list_, [model_state](Status load_status) {
model_state->status_ = load_status;
model_state->ready_.set_value();
});
if (!status.IsOk()) {
model_state->status_ = status;
model_state->ready_.set_value();
LOG_ERROR << "failed to load model '" << valid_model->model_name_
<< "': " << status.Message();
}
loaded_models.emplace(valid_model);
}
for (auto& model_state : model_states) {
model_state->ready_.get_future().wait();
res[model_state->node_->model_name_] = model_state->status_;
const auto version_state =
model_life_cycle_->VersionStates(model_state->node_->model_name_);
model_state->node_->loaded_versions_.clear();
for (const auto& vs : version_state) {
if (vs.second.first == ModelReadyState::READY) {
model_state->node_->loaded_versions_.emplace(vs.first);
}
}
// If the model failed to load, revert the timestamp so that the next load
// request will attempt to load the model again, for operational consistency.
if (!model_state->status_.IsOk()) {
auto& model_info = infos_.find(model_state->node_->model_name_)->second;
model_info->mtime_nsec_ = model_info->prev_mtime_ns_;
}
}
set_pair = ModelsToLoadUnload(loaded_models);
}
// Clear the temporarily stored agent model list after all loads are triggered
for (auto& info : infos_) {
info.second->agent_model_list_.reset();
}
return res;
}
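// Ordering sketch (illustrative only): ModelsToLoadUnload() returns the
// frontier of unchecked nodes whose upstreams are all checked, so an ensemble
// 'E' composed of models 'A' and 'B' is only attempted after they finish:
//
//   round 1: load {A, B}   // no upstreams
//   round 2: load {E}      // A and B are now checked and (ideally) loaded
//
// The loop above repeats until no affected node remains unchecked.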
Status
ModelRepositoryManager::LoadUnloadModel(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNAVAILABLE,
"explicit model load / unload is not allowed if polling is enabled");
}
if (models.size() > 1) {
return Status(
Status::Code::UNSUPPORTED,
"explicit load / unload multiple models is not currently supported");
}
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
bool polled = true;
RETURN_IF_ERROR(LoadUnloadModels(models, type, unload_dependents, &polled));
// Check if model is loaded / unloaded properly
const auto& model_name = models.begin()->first;
if (!polled) {
return Status(
Status::Code::INTERNAL, "failed to load '" + model_name +
"', failed to poll from model repository");
}
const auto version_states = model_life_cycle_->VersionStates(model_name);
if (type == ActionType::LOAD) {
if (version_states.empty()) {
return Status(
Status::Code::INTERNAL,
"failed to load '" + model_name + "', no version is available");
}
auto it = infos_.find(model_name);
if (it == infos_.end()) {
return Status(
Status::Code::INTERNAL,
"failed to load '" + model_name +
"', failed to poll from model repository");
}
} else {
std::string ready_version_str;
for (const auto& version_state : version_states) {
if (version_state.second.first == ModelReadyState::READY) {
ready_version_str += std::to_string(version_state.first);
ready_version_str += ",";
}
}
if (!ready_version_str.empty()) {
ready_version_str.pop_back();
return Status(
Status::Code::INTERNAL,
"failed to unload '" + model_name +
"', versions that are still available: " + ready_version_str);
}
}
return Status::Success;
}
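// Caller sketch (hedged): explicit load / unload typically reaches this method
// through the server C API, for example
//
//   TRITONSERVER_Error* err =
//       TRITONSERVER_ServerLoadModel(server, "resnet50");     // ActionType::LOAD
//   err = TRITONSERVER_ServerUnloadModel(server, "resnet50"); // ActionType::UNLOAD
//
// The exact call chain between that API and this manager lives outside this
// file; "resnet50" is just an example model name.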
Status
ModelRepositoryManager::LoadUnloadModels(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
const ActionType type, const bool unload_dependents,
bool* all_models_polled)
{
auto status = Status::Success;
*all_models_polled = true;
// Update ModelInfo related to file system accordingly
std::set<std::string> added, deleted, modified, unmodified;
{
if (type == ActionType::UNLOAD) {
for (const auto& model : models) {
deleted.insert(model.first);
}
}
// ActionType::LOAD and in model control mode
else {
std::set<std::string> checked_models;
auto current_models = models;
for (const auto& model : models) {
checked_models.emplace(model.first);
}
ModelInfoMap new_infos;
#ifdef TRITON_ENABLE_ENSEMBLE
bool first_iteration = true;
#endif // TRITON_ENABLE_ENSEMBLE
while (!current_models.empty()) {
bool polled = true;
RETURN_IF_ERROR(Poll(
current_models, &added, &deleted, &modified, &unmodified,
&new_infos, &polled));
*all_models_polled &= polled;
// More models should be polled if the polled models are ensembles
std::unordered_map<std::string, std::vector<const InferenceParameter*>>
next_models;
#ifdef TRITON_ENABLE_ENSEMBLE
for (const auto& model : current_models) {
auto it = new_infos.find(model.first);
// Some models may be marked as deleted and not in 'new_infos'
if (it != new_infos.end()) {
it->second->explicitly_load_ = first_iteration;
const auto& config = it->second->model_config_;
if (config.has_ensemble_scheduling()) {
for (const auto& step : config.ensemble_scheduling().step()) {
bool need_poll =
checked_models.emplace(step.model_name()).second;
if (need_poll) {
next_models[step.model_name()];
}
}
}
}
}
first_iteration = false;
#endif // TRITON_ENABLE_ENSEMBLE
current_models.swap(next_models);
}
// Only update the infos when all validation is completed
for (const auto& model_name : added) {
auto nitr = new_infos.find(model_name);
infos_.emplace(model_name, std::move(nitr->second));
}
for (const auto& model_name : modified) {
auto nitr = new_infos.find(model_name);
auto itr = infos_.find(model_name);
itr->second = std::move(nitr->second);
}
}
}
std::set<std::string> deleted_dependents;
// Update dependency graph and load
UpdateDependencyGraph(
added, deleted, modified,
unload_dependents ? &deleted_dependents : nullptr);
// The models are in 'deleted' either when they are asked to be unloaded or
// when they are not found / are duplicated across all model repositories.
// In all cases, they should be unloaded and removed from 'infos_' explicitly.
for (const auto& name : (unload_dependents ? deleted_dependents : deleted)) {
infos_.erase(name);
model_life_cycle_->AsyncUnload(name);
}
// load / unload the models affected, and check the load status of
// the requested models
const auto& load_status = LoadModelByDependency();
if (status.IsOk() && (type == ActionType::LOAD)) {
std::string load_error_message = "";
for (const auto& model : models) {
auto it = load_status.find(model.first);
// If 'model.first' is not in the load status, the (re-)load is not
// necessary because there is no change in the model's directory
if ((it != load_status.end()) && !it->second.IsOk()) {
load_error_message +=
("load failed for model '" + model.first +
"': " + it->second.Message() + "\n");
}
}
if (!load_error_message.empty()) {
status = Status(Status::Code::INVALID_ARG, load_error_message);
}
}
return status;
}
Status
ModelRepositoryManager::UnloadAllModels()
{
Status status;
for (const auto& name_info : infos_) {
Status unload_status = model_life_cycle_->AsyncUnload(name_info.first);
if (!unload_status.IsOk()) {
status = Status(
unload_status.ErrorCode(),
"Failed to gracefully unload models: " + unload_status.Message());
}
}
return Status::Success;
}
Status
ModelRepositoryManager::StopAllModels()
{
return model_life_cycle_->StopAllModels();
}
const std::set<std::tuple<std::string, int64_t, size_t>>
ModelRepositoryManager::InflightStatus()
{
return model_life_cycle_->InflightStatus();
}
const ModelStateMap
ModelRepositoryManager::LiveModelStates(bool strict_readiness)
{
return model_life_cycle_->LiveModelStates(strict_readiness);
}
const ModelStateMap
ModelRepositoryManager::ModelStates()
{
return model_life_cycle_->ModelStates();
}
const VersionStateMap
ModelRepositoryManager::VersionStates(const std::string& model_name)
{
return model_life_cycle_->VersionStates(model_name);
}
Status
ModelRepositoryManager::ModelState(
const std::string& model_name, const int64_t model_version,
ModelReadyState* state)
{
return model_life_cycle_->ModelState(model_name, model_version, state);
}
Status
ModelRepositoryManager::RepositoryIndex(
const bool ready_only, std::vector<ModelIndex>* index)
{
std::set<std::string> seen_models;
std::set<std::string> duplicate_models;
for (const auto& repository_path : repository_paths_) {
// For any mapped models in this repository, save the mapping
// from their subdirectory name to model name.
std::map<std::string, std::string> models_in_repo;
for (const auto& mapping_it : model_mappings_) {
if (mapping_it.second.first == repository_path) {
models_in_repo.emplace(
BaseName(mapping_it.second.second), mapping_it.first);
}
}
std::set<std::string> subdirs;
RETURN_IF_ERROR(GetDirectorySubdirs(repository_path, &subdirs));
for (const auto& subdir : subdirs) {
auto model = subdir;
auto model_it = models_in_repo.find(subdir);
if (model_it != models_in_repo.end()) {
model = model_it->second;
}
if (seen_models.find(model) != seen_models.end()) {
duplicate_models.insert(model);
}
seen_models.insert(model);
}
}
ModelStateMap states = ModelStates();
for (const auto& model : seen_models) {
// If the same model appears in multiple repositories then show it
// as unavailable since duplicate models are not allowed to load.
if (duplicate_models.find(model) != duplicate_models.end()) {
index->emplace_back(
model, -1 /* version */, ModelReadyState::UNAVAILABLE,
MODEL_READY_REASON_DUPLICATE);
continue;
}
// If there is any version/state/reason associated with the model
// then include that in the index.
auto sitr = states.find(model);
if (sitr == states.end()) {
if (!ready_only) {
index->emplace_back(model);
}
} else {
for (const auto& pr : sitr->second) {
if (!ready_only || (pr.second.first == ModelReadyState::READY)) {
index->emplace_back(
model, pr.first, pr.second.first, pr.second.second);
}
}
}
}
return Status::Success;
}
Status
ModelRepositoryManager::GetModel(
const std::string& model_name, const int64_t model_version,
std::shared_ptr<Model>* model)
{
Status status = model_life_cycle_->GetModel(model_name, model_version, model);
if (!status.IsOk()) {
model->reset();
status = Status(
status.ErrorCode(), "Request for unknown model: " + status.Message());
}
return status;
}
Status
ModelRepositoryManager::Poll(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models,
std::set<std::string>* added, std::set<std::string>* deleted,
std::set<std::string>* modified, std::set<std::string>* unmodified,
ModelInfoMap* updated_infos, bool* all_models_polled)
{
*all_models_polled = true;
// An empty path is the special case indicating that the model should be
// loaded from the override file content in 'models'.
std::map<std::string, std::string> model_to_path;
// If no model is specified, poll all models in all model repositories.
// Otherwise, only poll the specified models
if (models.empty()) {
std::set<std::string> duplicated_models;
for (const auto& repository_path : repository_paths_) {
std::set<std::string> subdirs;
Status status = GetDirectorySubdirs(repository_path, &subdirs);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll model repository '" << repository_path
<< "': " << status.Message();
*all_models_polled = false;
} else {
for (const auto& subdir : subdirs) {
if (!model_to_path
.emplace(subdir, JoinPath({repository_path, subdir}))
.second) {
duplicated_models.insert(subdir);
*all_models_polled = false;
}
}
}
}
// If the model is not unique, mark as deleted to unload it
for (const auto& model : duplicated_models) {
model_to_path.erase(model);
deleted->insert(model);
LOG_ERROR << "failed to poll model '" << model
<< "': not unique across all model repositories";
}
}
// If models are specified, this is explicit model control mode.
else {
for (const auto& model : models) {
// Skip repository polling if the model files are overridden
if (ModelDirectoryOverride(model.second)) {
model_to_path.emplace(model.first, "");
continue;
}
// Check the model mapping first to see if there is a matching model to load.
bool exists = false;
auto model_it = model_mappings_.find(model.first);
if (model_it != model_mappings_.end()) {
bool exists_in_this_repo = false;
auto full_path = model_it->second.second;
Status status = FileExists(full_path, &exists_in_this_repo);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll mapped path '" << full_path
<< "' for model '" << model.first
<< "': " << status.Message();
*all_models_polled = false;
}
if (exists_in_this_repo) {
model_to_path.emplace(model.first, model_it->second.second);
exists = true;
} else {
LOG_ERROR << "mapped path '" << full_path
<< "' does not exist for model '" << model.first << "'";
exists = false;
}
} else {
for (const auto& repository_path : repository_paths_) {
bool exists_in_this_repo = false;
const auto full_path = JoinPath({repository_path, model.first});
Status status = FileExists(full_path, &exists_in_this_repo);
if (!status.IsOk()) {
LOG_ERROR << "failed to poll model repository '" << repository_path
<< "' for model '" << model.first
<< "': " << status.Message();
*all_models_polled = false;
} else if (exists_in_this_repo) {
// Check to make sure this directory is not mapped.
// If mapped, continue to next repository path.
bool mapped = false;
for (auto const& mapping : model_mappings_) {
if (mapping.second.second == full_path) {
mapped = true;
break;
}
}
if (mapped) {
continue;
}
auto res = model_to_path.emplace(
model.first, JoinPath({repository_path, model.first}));
if (res.second) {
exists = true;
} else {
exists = false;
model_to_path.erase(res.first);
LOG_ERROR << "failed to poll model '" << model.first
<< "': not unique across all model repositories";
break;
}
}
}
}
// For an explicitly specified model that doesn't exist, we don't mark it
// as deleted; we simply mark that we couldn't poll all models.
if (!exists) {
*all_models_polled = false;
}
}
}
// Poll each of the models. If an error happens while polling a model,
// its state will fall back to the state before the polling.
for (const auto& pair : model_to_path) {
std::unique_ptr<ModelInfo> model_info;
const auto& mit = models.find(pair.first);
static std::vector<const InferenceParameter*> empty_params;
auto status = InitializeModelInfo(
pair.first, pair.second,
((mit == models.end()) ? empty_params : mit->second), &model_info);
const auto& iitr = infos_.find(pair.first);
const bool invalid_add = (!status.IsOk()) && (iitr == infos_.end());
if (!invalid_add) {
const auto& ret = updated_infos->emplace(pair.first, nullptr);
if (!ret.second) {
return Status(
Status::Code::ALREADY_EXISTS,
"unexpected model info for model '" + pair.first + "'");
}
// Classify load state and set updated info
if (model_info == nullptr) {
ret.first->second.reset(new ModelInfo(*iitr->second));
unmodified->insert(pair.first);
} else {
ret.first->second = std::move(model_info);
if (iitr != infos_.end()) {
modified->insert(pair.first);
} else {
added->insert(pair.first);
}
}
}
if (!status.IsOk()) {
LOG_ERROR << "Poll failed for model directory '" << pair.first
<< "': " << status.Message();
*all_models_polled = false;
}
}
return Status::Success;
}
bool
ModelRepositoryManager::ModelDirectoryOverride(
const std::vector<const InferenceParameter*>& model_params)
{
for (const auto& param : model_params) {
if (param->Name().rfind(file_prefix, 0) == 0) {
// param name starts with prefix if user provides override file
return true;
}
}
return false;
}
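// Naming sketch (hedged): override file parameters are recognized purely by
// name prefix. For instance, a load request carrying parameters named
// "file:1/model.onnx" and "config" would be treated as a directory override
// here, assuming 'file_prefix' is the "file:" prefix used by the load API;
// the parameter names above are illustrative.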
Status
ModelRepositoryManager::InitializeModelInfo(
const std::string& name, const std::string& path,
const std::vector<const InferenceParameter*>& params,
std::unique_ptr<ModelInfo>* info)
{
std::unique_ptr<ModelInfo> linfo(new ModelInfo());
linfo->model_path_ = path;
bool unmodified = false;
const auto iitr = infos_.find(name);
// Set 'prev_mtime_ns_' if there is existing ModelInfo
if (iitr != infos_.end()) {
linfo->prev_mtime_ns_ = iitr->second->mtime_nsec_;
} else {
linfo->prev_mtime_ns_ = 0;
}
// Set 'mtime_nsec_' and override 'model_path_' if current path is empty
// (file override is specified)
if (linfo->model_path_.empty()) {
// Need to localize the override files; use a repo agent to manage
// the lifecycle of the localized files
std::shared_ptr<TritonRepoAgent> localize_agent(new LocalizeRepoAgent());
std::unique_ptr<TritonRepoAgentModel> localize_agent_model;
RETURN_IF_ERROR(TritonRepoAgentModel::Create(
TRITONREPOAGENT_ARTIFACT_FILESYSTEM, "", inference::ModelConfig(),
localize_agent, {}, &localize_agent_model));
// Set the agent model state so the repo agent can access the encoded files.
// Using const_cast here, but we are safe as the RepoAgent will not
// modify the state.
localize_agent_model->SetState(
const_cast<void*>(reinterpret_cast<const void*>(&params)));
RETURN_IF_ERROR(
localize_agent_model->InvokeAgent(TRITONREPOAGENT_ACTION_LOAD));
const char* location;
TRITONREPOAGENT_ArtifactType type;
RETURN_IF_ERROR(localize_agent_model->Location(&type, &location));
// For a file override, set 'mtime_nsec_' to the minimum value so that the
// next load without an override will trigger a re-load to undo the override,
// even though the local files may be unchanged.
linfo->mtime_nsec_ = 0;
linfo->model_path_ = location;
linfo->agent_model_list_.reset(new TritonRepoAgentModelList());
linfo->agent_model_list_->AddAgentModel(std::move(localize_agent_model));
} else {
if (iitr == infos_.end()) {
linfo->mtime_nsec_ = GetModifiedTime(std::string(linfo->model_path_));
} else {
// Check the current timestamps to determine if the model has actually
// been modified
linfo->mtime_nsec_ = linfo->prev_mtime_ns_;
unmodified =
!IsModified(std::string(linfo->model_path_), &linfo->mtime_nsec_);
}
}
// Set 'model_config_'
bool parsed_config = false;
// Check if there is config override
for (const auto& override_parameter : params) {
if ((override_parameter->Name() == "config") &&
(override_parameter->Type() == TRITONSERVER_PARAMETER_STRING)) {
// When an override happens, set 'mtime_nsec_' to the minimum value so that
// the next load without an override will trigger a re-load to undo the
// override, even though the local files may be unchanged.
linfo->mtime_nsec_ = 0;
unmodified = false;
const std::string& override_config = override_parameter->ValueString();
auto err = JsonToModelConfig(
override_config, 1 /* config_version */, &linfo->model_config_);
if (!err.IsOk()) {
return Status(
Status::Code::INVALID_ARG,
"Invalid config override: " + std::string(err.Message()));
}
parsed_config = true;
break;
} else if (override_parameter->Name().rfind(file_prefix, 0) != 0) {
return Status(
Status::Code::INVALID_ARG,
"Unrecognized load parameter '" + override_parameter->Name() +
"' with type '" +
TRITONSERVER_ParameterTypeString(override_parameter->Type()) +
"'");
}
}
// The polled model is considered unmodified at this point and can be
// returned with 'info' == nullptr
if (unmodified) {
return Status::Success;
}
// Create the associated repo agent models when a model is to be loaded.
// This must be done before normalizing the model config, as agents might
// redirect to a model config at a different location.
if (!parsed_config) {
const auto config_path = JoinPath({linfo->model_path_, kModelConfigPbTxt});
bool model_config_exists = false;
RETURN_IF_ERROR(FileExists(config_path, &model_config_exists));
// The model config can be missing if autofill is enabled
if (autofill_ && !model_config_exists) {
linfo->model_config_.Clear();
} else {
RETURN_IF_ERROR(ReadTextProto(config_path, &linfo->model_config_));
parsed_config = true;
}
}
if (parsed_config) {
RETURN_IF_ERROR(CreateAgentModelListWithLoadAction(
linfo->model_config_, linfo->model_path_, &linfo->agent_model_list_));
if (linfo->agent_model_list_ != nullptr) {
// Get the latest repository path
const char* location;
TRITONREPOAGENT_ArtifactType artifact_type;
RETURN_IF_ERROR(linfo->agent_model_list_->Back()->Location(
&artifact_type, &location));
auto latest_path = std::string(location);
linfo->model_path_ = latest_path;
}
}
linfo->is_config_provided_ = parsed_config;
// Try to automatically generate missing parts of the model
// configuration (autofill) that don't require model detail
RETURN_IF_ERROR(GetNormalizedModelConfig(
name, linfo->model_path_, min_compute_capability_,
&linfo->model_config_));
// Note that the model inputs and outputs are not validated until the
// model is initialized, as they may not be auto-completed until then.
RETURN_IF_ERROR(
ValidateModelConfig(linfo->model_config_, min_compute_capability_));
if (!autofill_) {
RETURN_IF_ERROR(ValidateModelIOConfig(linfo->model_config_));
}
// If the model is mapped, update its config name based on the
// mapping.
if (model_mappings_.find(name) != model_mappings_.end()) {
linfo->model_config_.set_name(name);
} else {
// If there is no model mapping, make sure the name of the model
// matches the name of the directory. This is a somewhat arbitrary
// requirement but seems like good practice to require it of the user.
// It also acts as a check to make sure we don't have two different
// models with the same name.
if (linfo->model_config_.name() != name) {
return Status(
Status::Code::INVALID_ARG,
"unexpected directory name '" + name + "' for model '" +
linfo->model_config_.name() +
"', directory name must equal model name");
}
}
*info = std::move(linfo);
return Status::Success;
}
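// Config-override sketch (hedged, illustrative fields): the "config" load
// parameter handled above carries the model configuration as a JSON string,
// for example
//
//   { "name": "resnet50", "backend": "onnxruntime", "max_batch_size": 8 }
//
// which JsonToModelConfig() parses into 'model_config_' in place of the
// repository's config.pbtxt.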
Status
ModelRepositoryManager::UpdateDependencyGraph(
const std::set<std::string>& added, const std::set<std::string>& deleted,
const std::set<std::string>& modified,
std::set<std::string>* deleted_dependents)
{
// Update the dependency graph; if the state of a node changes, all of its
// downstreams will be affected.
// For deleted nodes: drop them from 'dependency_graph_' and add them to
// 'missing_nodes_' if their downstreams are not empty. The affected nodes are
// all ensembles, as only ensembles depend on other models.
std::set<DependencyNode*> affected_nodes;
std::set<DependencyNode*> updated_nodes;
std::set<std::string> current_deleted = deleted;
while (!current_deleted.empty()) {
std::set<std::string> next_deleted;
for (const auto& model_name : current_deleted) {
auto it = dependency_graph_.find(model_name);
if (it != dependency_graph_.end()) {
// remove this node from its upstreams
for (auto& upstream : it->second->upstreams_) {
upstream.first->downstreams_.erase(it->second.get());
// Check if the upstream should be removed as well
if ((deleted_dependents != nullptr) &&
(upstream.first->downstreams_.empty()) &&
(!upstream.first->explicitly_load_)) {
next_deleted.emplace(upstream.first->model_name_);
}
}
it->second->upstreams_.clear();
if (!it->second->downstreams_.empty()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
// mark this node as a missing upstream in its downstream nodes
for (auto& downstream : it->second->downstreams_) {
downstream->missing_upstreams_.emplace(it->second.get());
}
missing_nodes_.emplace(
std::make_pair(model_name, std::move(it->second)));
}
// Make sure deleted node will not be in affected nodes
affected_nodes.erase(it->second.get());
dependency_graph_.erase(it);
}
if (deleted_dependents != nullptr) {
deleted_dependents->emplace(model_name);
}
}
current_deleted.swap(next_deleted);
}
// For modified nodes: invalidate (uncheck) all downstreams
for (const auto& model_name : modified) {
auto it = dependency_graph_.find(model_name);
if (it != dependency_graph_.end()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
ModelInfo* info = nullptr;
GetModelInfo(model_name, &info);
it->second->model_config_ = info->model_config_;
it->second->explicitly_load_ = info->explicitly_load_;
// remove this node from its upstream node
for (auto& upstream : it->second->upstreams_) {
upstream.first->downstreams_.erase(it->second.get());
}
it->second->upstreams_.clear();
it->second->checked_ = false;
it->second->status_ = Status::Success;
updated_nodes.emplace(it->second.get());
}
}
// For added nodes: add them to 'dependency_graph_'. If a node is in
// 'missing_nodes_', invalidate (uncheck) and associate all of its downstreams,
// then remove it from 'missing_nodes_'.
for (const auto& model_name : added) {
std::unique_ptr<DependencyNode> added_node;
auto it = missing_nodes_.find(model_name);
if (it != missing_nodes_.end()) {
UncheckDownstream(&it->second->downstreams_, &affected_nodes);
// remove this node from the missing-upstream sets of its downstream nodes
for (auto& downstream : it->second->downstreams_) {
downstream->missing_upstreams_.erase(it->second.get());
}
it->second->checked_ = false;
added_node = std::move(it->second);
missing_nodes_.erase(it);
} else {
// Right now, nothing is going to be filled until validation
added_node.reset(new DependencyNode(model_name));
}
ModelInfo* info = nullptr;
GetModelInfo(model_name, &info);
added_node->model_config_ = info->model_config_;
added_node->explicitly_load_ = info->explicitly_load_;
updated_nodes.emplace(added_node.get());
dependency_graph_.emplace(
std::make_pair(model_name, std::move(added_node)));
}
auto& affected_ensembles = affected_nodes;
for (auto& updated_node : updated_nodes) {
bool is_ensemble = ConnectDependencyGraph(updated_node);
if (is_ensemble) {
affected_ensembles.emplace(updated_node);
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// After the dependency graph is updated, check ensemble dependencies
for (auto& ensemble : affected_ensembles) {
if (ensemble->status_.IsOk()) {
if (!ensemble->missing_upstreams_.empty()) {
std::string name_list;
for (auto it = ensemble->missing_upstreams_.begin();
it != ensemble->missing_upstreams_.end(); it++) {
if (it != ensemble->missing_upstreams_.begin()) {
name_list += ", ";
}
name_list += (*it)->model_name_;
}
ensemble->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble " + ensemble->model_name_ +
" contains models that are not available: " + name_list);
} else {
ensemble->status_ = CircularcyCheck(ensemble, ensemble);
}
}
}
#endif // TRITON_ENABLE_ENSEMBLE
return Status::Success;
}
Status
ModelRepositoryManager::RegisterModelRepository(
const std::string& repository,
const std::unordered_map<std::string, std::string>& model_mapping)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNSUPPORTED,
"repository registration is not allowed if model control mode is not "
"EXPLICIT");
}
bool is_directory = false;
auto status = IsDirectory(repository, &is_directory);
if (!status.IsOk() || !is_directory) {
return Status(
Status::Code::INVALID_ARG, (std::string("failed to register '") +
repository + "', repository not found")
.c_str());
}
{
// Serialize all operations that change model state
std::lock_guard<std::mutex> lock(poll_mu_);
// Check repository and mapped models do not yet exist.
if (repository_paths_.find(repository) != repository_paths_.end()) {
return Status(
Status::Code::ALREADY_EXISTS,
"model repository '" + repository + "' has already been registered");
}
for (const auto& mapping : model_mapping) {
if (model_mappings_.find(mapping.first) != model_mappings_.end()) {
return Status(
Status::Code::ALREADY_EXISTS,
(std::string("failed to register '") + mapping.first +
"', there is a conflicting mapping for '" +
std::string(mapping.first) + "'")
.c_str());
}
}
repository_paths_.emplace(repository);
for (const auto& mapping : model_mapping) {
model_mappings_.emplace(
mapping.first,
std::make_pair(repository, JoinPath({repository, mapping.second})));
}
}
LOG_INFO << "Model repository registered: " << repository;
return Status::Success;
}
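// API sketch (hedged): repository registration is normally driven by the
// server C API, e.g.
//
//   TRITONSERVER_ServerRegisterModelRepository(
//       server, "/new/model/repo", nullptr /* name_mapping */, 0);
//
// with an optional name mapping from model name to subdirectory; the exact
// signature and availability may vary between Triton releases.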
Status
ModelRepositoryManager::UnregisterModelRepository(const std::string& repository)
{
if (!model_control_enabled_) {
return Status(
Status::Code::UNSUPPORTED,
"repository unregistration is not allowed if model control mode is not "
"EXPLICIT");
}
{
std::lock_guard<std::mutex> lock(poll_mu_);
if (repository_paths_.erase(repository) != 1) {
return Status(
Status::Code::INVALID_ARG,
"failed to unregister '" + repository + "', repository not found");
}
std::set<std::string> models_to_delete;
for (auto const& mapping : model_mappings_) {
if (mapping.second.first == repository) {
models_to_delete.insert(mapping.first);
}
}
for (auto const& model : models_to_delete) {
model_mappings_.erase(model);
}
}
LOG_INFO << "Model repository unregistered: " << repository;
return Status::Success;
}
Status
ModelRepositoryManager::CircularcyCheck(
DependencyNode* current_node, const DependencyNode* start_node)
{
for (auto& downstream : current_node->downstreams_) {
if (downstream->model_name_ == start_node->model_name_) {
return Status(
Status::Code::INVALID_ARG,
"circular dependency between ensembles: " + start_node->model_name_ +
" -> ... -> " + current_node->model_name_ + " -> " +
start_node->model_name_);
} else {
const auto status = CircularcyCheck(downstream, start_node);
if (!status.IsOk() && current_node->status_.IsOk()) {
current_node->status_ = status;
return status;
}
}
}
return Status::Success;
}
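// Example of the condition detected above (illustrative): if ensemble "A"
// has a scheduling step that references ensemble "B", and "B" in turn
// references "A", the recursion reaches 'start_node' again and reports
// "circular dependency between ensembles: A -> ... -> B -> A".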
void
ModelRepositoryManager::UncheckDownstream(
NodeSet* downstreams, NodeSet* updated_nodes)
{
// Mark downstream nodes as unchecked recursively
for (auto& node : *downstreams) {
if (node->checked_) {
node->checked_ = false;
node->status_ = Status::Success;
UncheckDownstream(&node->downstreams_, updated_nodes);
updated_nodes->emplace(node);
}
}
}
bool
ModelRepositoryManager::ConnectDependencyGraph(DependencyNode* updated_node)
{
// Check the node's model config to determine if it depends on other models
// and if those models are present
updated_node->upstreams_.clear();
updated_node->missing_upstreams_.clear();
if (updated_node->model_config_.has_ensemble_scheduling()) {
for (const auto& step :
updated_node->model_config_.ensemble_scheduling().step()) {
DependencyNode* upstream_node = nullptr;
const auto& model_name = step.model_name();
auto dit = dependency_graph_.find(model_name);
if (dit == dependency_graph_.end()) {
auto mit = missing_nodes_.find(model_name);
if (mit == missing_nodes_.end()) {
std::unique_ptr<DependencyNode> node(new DependencyNode(model_name));
updated_node->missing_upstreams_.emplace(node.get());
mit = missing_nodes_.emplace(model_name, std::move(node)).first;
}
// Add the node to the missing node's downstreams so that when the missing
// node is added, its downstreams can be found easily.
mit->second->downstreams_.emplace(updated_node);
upstream_node = mit->second.get();
} else {
dit->second->downstreams_.emplace(updated_node);
upstream_node = dit->second.get();
}
auto res = updated_node->upstreams_.emplace(
upstream_node, std::set<int64_t>({step.model_version()}));
// If the map insertion doesn't happen, the same model is required in a
// different step; insert the version into the existing required-version set.
if (!res.second) {
res.first->second.insert(step.model_version());
}
}
return true;
}
return false;
}
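// Graph sketch (illustrative config): an ensemble whose config contains
//
//   ensemble_scheduling {
//     step { model_name: "preprocess" model_version: -1 }
//     step { model_name: "classifier" model_version: 2 }
//   }
//
// gains upstream edges {preprocess -> {-1}, classifier -> {2}} here, with any
// model not yet in the graph parked in 'missing_nodes_' until it appears.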
Status
ModelRepositoryManager::GetModelInfo(
const std::string& name, ModelInfo** model_info)
{
const auto itr = infos_.find(name);
if (itr == infos_.end()) {
return Status(
Status::Code::NOT_FOUND, "no configuration for model '" + name + "'");
}
*model_info = itr->second.get();
return Status::Success;
}
std::pair<ModelRepositoryManager::NodeSet, ModelRepositoryManager::NodeSet>
ModelRepositoryManager::ModelsToLoadUnload(const NodeSet& loaded_models)
{
// <valid model set, invalid model set>
std::pair<NodeSet, NodeSet> res;
// first call to this function
if (loaded_models.empty()) {
for (auto& pair : dependency_graph_) {
auto node = pair.second.get();
// only care about nodes that are affected by the update
if (!node->checked_) {
if (CheckNode(node)) {
if (node->status_.IsOk()) {
res.first.emplace(node);
} else {
res.second.emplace(node);
}
}
}
}
} else {
for (const auto& model : loaded_models) {
for (auto node : model->downstreams_) {
// only care about nodes that are affected by the update
if (!node->checked_) {
if (CheckNode(node)) {
if (node->status_.IsOk()) {
res.first.emplace(node);
} else {
res.second.emplace(node);
}
}
}
}
}
}
for (auto& node : res.first) {
node->checked_ = true;
}
for (auto& node : res.second) {
node->checked_ = true;
}
return res;
}
bool
ModelRepositoryManager::CheckNode(DependencyNode* node)
{
bool node_ready = true;
// If the node is in an invalid status, mark it as ready since we know
// it should not be loaded
if (node->status_.IsOk()) {
for (auto& upstream : node->upstreams_) {
if (!upstream.first->checked_) {
node_ready = false;
break;
}
if (!upstream.first->status_.IsOk()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' which is not valid");
} else if (upstream.first->loaded_versions_.empty()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' which has no loaded version");
} else {
for (const auto& required_version : upstream.second) {
if (required_version == -1) {
continue;
}
auto it = upstream.first->loaded_versions_.find(required_version);
if (it == upstream.first->loaded_versions_.end()) {
node->status_ = Status(
Status::Code::INVALID_ARG,
"ensemble '" + node->model_name_ + "' depends on '" +
upstream.first->model_name_ + "' whose required version " +
std::to_string(required_version) + " is not loaded");
}
}
}
if (!node->status_.IsOk()) {
break;
}
}
#ifdef TRITON_ENABLE_ENSEMBLE
// Validate the ensemble config if the node is ready. By this point, the
// models it depends on are loaded and their configs are complete.
if (node_ready && node->status_.IsOk()) {
node->status_ = ValidateEnsembleConfig(this, node);
}
#endif // TRITON_ENABLE_ENSEMBLE
}
return node_ready;
}
}} // namespace triton::core