Commit b30f3cdb authored by xiabo

Add the downloaded code

parent e38ee081
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "ensemble_model.h"
#include <stdint.h>
#include "constants.h"
#include "ensemble_scheduler.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
Status
EnsembleModel::Create(
InferenceServer* server, const std::string& path, const int64_t version,
const inference::ModelConfig& model_config, const bool is_config_provided,
const double min_compute_capability, std::unique_ptr<Model>* model)
{
// Create the ensemble model.
std::unique_ptr<EnsembleModel> local_model(
new EnsembleModel(min_compute_capability, path, version, model_config));
RETURN_IF_ERROR(local_model->Init(is_config_provided));
std::unique_ptr<Scheduler> scheduler;
RETURN_IF_ERROR(EnsembleScheduler::Create(
local_model->MutableStatsAggregator(), server, model_config, &scheduler));
RETURN_IF_ERROR(local_model->SetScheduler(std::move(scheduler)));
LOG_VERBOSE(1) << "ensemble model for " << local_model->Name() << std::endl;
*model = std::move(local_model);
return Status::Success;
}
std::ostream&
operator<<(std::ostream& out, const EnsembleModel& pb)
{
out << "name=" << pb.Name() << std::endl;
return out;
}
}} // namespace triton::core
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "model.h"
#include "model_config.pb.h"
#include "scheduler.h"
#include "status.h"
namespace triton { namespace core {
class InferenceServer;
class EnsembleModel : public Model {
public:
EnsembleModel(EnsembleModel&&) = default;
static Status Create(
InferenceServer* server, const std::string& path, const int64_t version,
const inference::ModelConfig& model_config, const bool is_config_provided,
const double min_compute_capability, std::unique_ptr<Model>* model);
private:
DISALLOW_COPY_AND_ASSIGN(EnsembleModel);
explicit EnsembleModel(
const double min_compute_capability, const std::string& model_dir,
const int64_t version, const inference::ModelConfig& config)
: Model(min_compute_capability, model_dir, version, config)
{
}
friend std::ostream& operator<<(std::ostream&, const EnsembleModel&);
};
std::ostream& operator<<(std::ostream& out, const EnsembleModel& pb);
}} // namespace triton::core
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_scheduler.h"
#include <mutex>
#include "cuda_utils.h"
#include "metrics.h"
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
namespace {
class EnsembleContext;
using IterationCount = size_t;
// The request tracker is passed as 'userp' to the RequestRelease function and
// is used to manage the lifecycle of the ensemble request
class RequestTracker {
public:
explicit RequestTracker(
std::unique_ptr<InferenceRequest>&& request, uint64_t compute_start_ns,
MetricModelReporter* metric_reporter,
InferenceStatsAggregator* stats_aggregator)
: inflight_request_counter_(1), request_(std::move(request)),
compute_start_ns_(compute_start_ns), metric_reporter_(metric_reporter),
stats_aggregator_(stats_aggregator), status_(Status::Success)
{
}
std::unique_ptr<InferenceRequest>& Request() { return request_; }
InferenceStatsAggregator& ContextStatsAggregator()
{
return context_stats_aggregator_;
}
void IncrementCounter()
{
std::lock_guard<std::mutex> lk(mtx_);
inflight_request_counter_++;
}
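// Decrement the in-flight count. When it reaches zero, report the aggregated
// statistics (if stats are enabled) and release the ensemble request.
// Returns true when the tracker itself can be deleted by the caller.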
bool DecrementCounter()
{
std::lock_guard<std::mutex> lk(mtx_);
inflight_request_counter_--;
if (inflight_request_counter_ == 0) {
#ifdef TRITON_ENABLE_STATS
const auto& infer_stats = context_stats_aggregator_.ImmutableInferStats();
request_->ReportStatisticsWithDuration(
metric_reporter_, status_.IsOk(), compute_start_ns_,
infer_stats.compute_input_duration_ns_,
infer_stats.compute_infer_duration_ns_,
infer_stats.compute_output_duration_ns_);
if (status_.IsOk()) {
stats_aggregator_->UpdateInferBatchStatsWithDuration(
metric_reporter_, std::max(1U, request_->BatchSize()),
infer_stats.compute_input_duration_ns_,
infer_stats.compute_infer_duration_ns_,
infer_stats.compute_output_duration_ns_);
}
#endif
InferenceRequest::Release(
std::move(request_), TRITONSERVER_REQUEST_RELEASE_ALL);
}
return (inflight_request_counter_ == 0);
}
void SetStatus(const Status& status)
{
std::lock_guard<std::mutex> lk(mtx_);
status_ = status;
}
private:
std::mutex mtx_;
uint32_t inflight_request_counter_;
std::unique_ptr<InferenceRequest> request_;
uint64_t compute_start_ns_;
MetricModelReporter* metric_reporter_;
InferenceStatsAggregator* stats_aggregator_;
InferenceStatsAggregator context_stats_aggregator_;
Status status_;
};
// Step is used as 'userp' and keeps the ensemble context alive
// until no more internal requests are in flight.
// Step contains metadata and status for the
// internal infer request
struct Step {
Step(
size_t step_idx, const InferenceRequest::SequenceId& correlation_id,
uint32_t flags)
: correlation_id_(correlation_id), flags_(flags), response_flags_(0),
infer_status_(nullptr), step_idx_(step_idx)
{
}
std::shared_ptr<EnsembleContext> ctx_;
std::unique_ptr<InferenceRequest> request_;
InferenceRequest::SequenceId correlation_id_;
uint32_t flags_;
std::mutex output_mtx_;
// Separate output maps to avoid address conflicts between different memory types
std::unordered_map<uintptr_t, std::shared_ptr<AllocatedMemory>>
cpu_output_map_;
std::unordered_map<
int64_t, std::unordered_map<uintptr_t, std::shared_ptr<AllocatedMemory>>>
gpu_output_map_;
std::set<std::pair<std::string, IterationCount>> updated_tensors_;
uint32_t response_flags_;
TRITONSERVER_Error* infer_status_;
size_t step_idx_;
};
struct TensorData {
struct Metadata {
Metadata() = default;
Metadata(
std::unique_ptr<InferenceRequest::Input>&& data, size_t reference_count)
: data_(std::move(data)), remaining_reference_count_(reference_count),
parameter_override_(false)
{
}
Metadata(
std::unique_ptr<InferenceRequest::Input>&& data, size_t reference_count,
const InferenceRequest::SequenceId& correlation_id, uint32_t flags)
: data_(std::move(data)), remaining_reference_count_(reference_count),
parameter_override_(true), correlation_id_(correlation_id),
flags_(flags)
{
}
std::unique_ptr<InferenceRequest::Input> data_;
size_t remaining_reference_count_;
bool parameter_override_;
InferenceRequest::SequenceId correlation_id_;
uint32_t flags_;
};
TensorData() = default;
TensorData(const size_t outgoing_steps_count)
: current_iteration_(0), outgoing_steps_count_(outgoing_steps_count),
batch_size_(0)
{
}
IterationCount AddTensor(std::unique_ptr<InferenceRequest::Input>&& tensor)
{
tensor_.emplace(
current_iteration_, Metadata(std::move(tensor), outgoing_steps_count_));
return current_iteration_++;
}
IterationCount AddTensor(
std::unique_ptr<InferenceRequest::Input>&& tensor,
const InferenceRequest::SequenceId& correlation_id, uint32_t flags)
{
tensor_.emplace(
current_iteration_,
Metadata(
std::move(tensor), outgoing_steps_count_, correlation_id, flags));
return current_iteration_++;
}
// Tensors associated with this particular ensemble tensor.
// A container is used to handle the decoupled case,
// where a variable number of tensors may be produced.
// Maps 'iteration count' to <tensor, remaining outgoing count>
std::unordered_map<IterationCount, Metadata> tensor_;
size_t current_iteration_;
size_t outgoing_steps_count_;
// An ensemble may be configured to pass a tensor between a batching model and
// a non-batching model as long as the full shapes match, so the batch size of
// the generated tensor is stored explicitly for checking and for setting the
// proper shape on the downstream model request.
size_t batch_size_;
};
// EnsembleContext maintains the state of the ensemble request
//
// The static functions take advantage of shared_ptr: a copy of the shared_ptr
// is made when a step is scheduled and it goes out of scope after the step's
// callback has finished. The step's callback will schedule new steps if any
// are available, and the last step will finish the ensemble request.
// So we don't have to maintain the context in the scheduler, as the shared_ptr
// will destroy the context for us once there are no "in-flight" steps.
class EnsembleContext {
public:
EnsembleContext(
MetricModelReporter* metric_reporter,
InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
cudaStream_t stream);
// Perform transition on 'context' state given the information of
// 'completed_step'
static void Proceed(
const std::shared_ptr<EnsembleContext>& context,
const std::unique_ptr<Step>& completed_step = nullptr);
private:
static TRITONSERVER_Error* ResponseAlloc(
TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
int64_t preferred_memory_type_id, void* userp, void** buffer,
void** buffer_userp, TRITONSERVER_MemoryType* allocated_memory_type,
int64_t* allocated_memory_type_id);
static TRITONSERVER_Error* ResponseRelease(
TRITONSERVER_ResponseAllocator* allocator, void* buffer,
void* buffer_userp, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
static TRITONSERVER_Error* OutputBufferQuery(
TRITONSERVER_ResponseAllocator* allocator, void* userp,
const char* tensor_name, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);
static void RequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags,
void* userp);
static void ResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags,
void* userp);
using StepList = std::vector<std::unique_ptr<Step>>;
using VersionMap = std::unordered_map<int64_t, std::shared_ptr<Model>>;
// Helper function to reshape the given tensor according to the config shape
// and batching info as well as its actual shape and batching info.
// Note that 'dims' is the full shape, as opposed to 'config_dims'.
// Returns the dims after the reshape.
std::vector<int64_t> ReshapeTensorDims(
const triton::common::DimsList& config_dims,
const bool config_allow_batching, const size_t tensor_batch_size,
const std::vector<int64_t>& dims);
// Return the list of steps that become ready due to the tensor updates
// from 'completed_step'
Status PrepareSteps(
const std::unique_ptr<Step>& completed_step, StepList* steps);
// Prepare infer stats and call the inference server's function to process
// the infer requests specified in 'steps'
static void ScheduleSteps(
const std::shared_ptr<EnsembleContext>& context, StepList&& steps);
// Helper function that updates ensemble state given 'completed_step' and
// returns the list of updated tensors in 'updated_tensors'
Status UpdateEnsembleState(
const std::unique_ptr<Step>& completed_step,
std::set<std::pair<std::string, IterationCount>>* updated_tensors);
// Helper function that returns a list of 'steps' that should be run under
// current ensemble state. 'updated_tensors' is used so that we don't need to
// iterate all the tensors to determine which step can be run.
Status GetNextSteps(
const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
StepList* steps);
// Helper function that completes the response of the ensemble request
Status FinishEnsemble(
std::unique_ptr<InferenceResponse>&& response = nullptr);
// Helper function that initializes 'step' given the info at 'step_idx'.
// The 'step' will have the proper request / response provider for the model
Status InitStep(
const size_t step_idx, const IterationCount iteration_count,
std::unique_ptr<Step>* step);
// Helper function that sets the output of the ensemble request if it is ready
// and valid.
Status CheckAndSetEnsembleOutput(
const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
std::unique_ptr<InferenceResponse>* response);
InferenceServer* is_;
EnsembleInfo* info_;
// All EnsembleContext will use the same CUDA stream managed by
// the ensemble scheduler
cudaStream_t stream_;
// Mutex to avoid concurrent calls to 'PrepareSteps', where the ensemble state
// is being modified
std::mutex mutex_;
size_t inflight_step_counter_;
// pointer that either points to 'pruned_tensor_to_step_' or to
// 'info_->tensor_to_step_' if all ensemble outputs are requested
std::unordered_map<std::string, std::set<size_t>>* tensor_to_step_;
std::unordered_map<std::string, std::set<size_t>> pruned_tensor_to_step_;
std::unordered_map<std::string, TensorData> tensor_data_;
// Handle to all models that may be used in the ensemble
std::unordered_map<std::string, VersionMap> handles_;
// Request-specific information obtained from the ensemble request that
// should be applied to all internal requests
uint32_t flags_;
std::string request_id_;
InferenceRequest::SequenceId correlation_id_;
uint32_t priority_;
uint64_t timeout_;
// Objects related to the ensemble infer request
Status ensemble_status_;
RequestTracker* request_tracker_;
// The allocator that will be used to allocate buffers for the
// inference result tensors.
std::unique_ptr<
TRITONSERVER_ResponseAllocator,
decltype(&TRITONSERVER_ResponseAllocatorDelete)>
allocator_;
};
EnsembleContext::EnsembleContext(
MetricModelReporter* metric_reporter,
InferenceStatsAggregator* stats_aggregator, InferenceServer* is,
EnsembleInfo* info, std::unique_ptr<InferenceRequest>& request,
cudaStream_t stream)
: is_(is), info_(info), stream_(stream), inflight_step_counter_(0),
allocator_(nullptr, TRITONSERVER_ResponseAllocatorDelete)
{
uint64_t compute_start_ns = 0;
INFER_STATS_SET_TIMESTAMP(compute_start_ns);
request_tracker_ = new RequestTracker(
std::move(request), compute_start_ns, metric_reporter, stats_aggregator);
auto& lrequest = request_tracker_->Request();
// Obtain model handles of all models in ensemble request such that
// they have the same lifetime as the ensemble request to avoid unloading
// while the ensemble is executing.
for (const auto& step_info : info_->steps_) {
auto it = handles_.find(step_info.model_name_);
if (it == handles_.end()) {
it = handles_.emplace(std::make_pair(step_info.model_name_, VersionMap()))
.first;
}
auto ver_it = it->second.find(step_info.model_version_);
if (ver_it == it->second.end()) {
std::shared_ptr<Model> model = nullptr;
ensemble_status_ = is_->GetModel(
step_info.model_name_, step_info.model_version_, &model);
if (!ensemble_status_.IsOk()) {
break;
}
it->second.emplace(std::make_pair(step_info.model_version_, model));
}
}
// Prune ensemble first if not all outputs are requested
std::set<std::string> ignored_tensor;
for (const auto& ensemble_output : info_->ensemble_output_shape_) {
ignored_tensor.insert(ensemble_output.first);
}
for (const auto& requested_output : lrequest->ImmutableRequestedOutputs()) {
ignored_tensor.erase(requested_output);
}
if (ignored_tensor.empty()) {
tensor_to_step_ = &(info_->tensor_to_step_);
} else {
pruned_tensor_to_step_ = info_->tensor_to_step_;
tensor_to_step_ = &pruned_tensor_to_step_;
// Backward traversal
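// Starting from the ignored ensemble outputs, decrement the requested-output
// count of the step that produces each ignored tensor. Once a step has no
// requested outputs left, remove it as a consumer of its input tensors; any
// input tensor that loses all of its consumers becomes ignored in the next pass.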
std::unordered_map<size_t, size_t> step_requested_output_count;
while (!ignored_tensor.empty()) {
std::set<std::string> new_ignored_tensor;
for (const auto& output : ignored_tensor) {
auto step_idx = info_->tensor_to_prev_step_[output];
auto& step = info_->steps_[step_idx];
auto it = step_requested_output_count.find(step_idx);
if (it == step_requested_output_count.end()) {
auto output_count = step.output_to_tensor_.size();
it =
step_requested_output_count.emplace(step_idx, output_count).first;
}
// If none of the outputs of the step is requested,
// then the step can be pruned
if (--it->second == 0) {
for (const auto& input : step.input_to_tensor_) {
auto& step_set = pruned_tensor_to_step_[input.second];
step_set.erase(step_idx);
// If all steps that depend on a tensor are pruned,
// then the tensor can be ignored.
if (step_set.empty()) {
new_ignored_tensor.insert(input.second);
}
}
}
}
ignored_tensor.swap(new_ignored_tensor);
}
}
for (const auto& pair : *tensor_to_step_) {
const auto& requested_outputs = lrequest->ImmutableRequestedOutputs();
// For requested outputs, add 1 to the outgoing count as the ensemble itself
// isn't counted as a step.
if (requested_outputs.find(pair.first) != requested_outputs.end()) {
tensor_data_.emplace(pair.first, TensorData(pair.second.size() + 1));
} else {
tensor_data_.emplace(pair.first, TensorData(pair.second.size()));
}
}
if (ensemble_status_.IsOk()) {
request_id_ = lrequest->Id();
correlation_id_ = lrequest->CorrelationId();
flags_ = lrequest->Flags();
priority_ = lrequest->Priority();
timeout_ = lrequest->TimeoutMicroseconds();
for (const auto& pr : lrequest->ImmutableInputs()) {
const InferenceRequest::Input* input = pr.second;
auto it = tensor_data_.find(input->Name());
if (it != tensor_data_.end()) {
auto& tensor_data = it->second;
// Shape() represents the reshaped value without the batch dimension,
// so the batch dimension needs to be filled in if necessary.
std::unique_ptr<InferenceRequest::Input> tensor;
if (lrequest->BatchSize() != 0) {
std::vector<int64_t> shape{lrequest->BatchSize()};
shape.insert(
shape.end(), input->Shape().begin(), input->Shape().end());
tensor.reset(new InferenceRequest::Input(
input->Name(), input->DType(), shape));
} else {
tensor.reset(new InferenceRequest::Input(
input->Name(), input->DType(), input->Shape()));
}
tensor->SetData(input->Data());
for (const auto& host_policy_data : input->HostPolicyData()) {
tensor->SetData(host_policy_data.first, host_policy_data.second);
}
tensor_data.AddTensor(std::move(tensor));
tensor_data.batch_size_ = lrequest->BatchSize();
} else {
ensemble_status_ = Status(
Status::Code::INVALID_ARG,
lrequest->LogRequest() + "unexpected input '" + input->Name() +
"' in request header that does not map to any ensemble inputs");
}
}
// Iterate over the ensemble's optional inputs and add an empty tensor data
// entry if the input is not provided
for (const auto& name : info_->optional_inputs_) {
auto it = tensor_data_.find(name);
if ((it != tensor_data_.end()) && it->second.tensor_.empty()) {
it->second.AddTensor(nullptr);
it->second.batch_size_ = lrequest->BatchSize();
}
}
}
TRITONSERVER_ResponseAllocator* allocator;
TRITONSERVER_Error* err = TRITONSERVER_ResponseAllocatorNew(
&allocator, ResponseAlloc, ResponseRelease, nullptr /* start_fn */);
if (err == nullptr) {
err = TRITONSERVER_ResponseAllocatorSetQueryFunction(
allocator, OutputBufferQuery);
}
if (err != nullptr) {
ensemble_status_ = Status(
TritonCodeToStatusCode(TRITONSERVER_ErrorCode(err)),
TRITONSERVER_ErrorMessage(err));
TRITONSERVER_ErrorDelete(err);
} else {
allocator_.reset(allocator);
}
}
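// Allocator callback for internal requests. Allocates an AllocatedMemory block
// of the preferred memory type and stores it in the step's CPU / GPU output map,
// keyed by buffer address, so that ResponseComplete() can later hand the block
// to the produced ensemble tensor without copying.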
TRITONSERVER_Error*
EnsembleContext::ResponseAlloc(
TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
int64_t preferred_memory_type_id, void* userp, void** buffer,
void** buffer_userp, TRITONSERVER_MemoryType* allocated_memory_type,
int64_t* allocated_memory_type_id)
{
*buffer = nullptr;
*buffer_userp = nullptr;
auto allocated_buffer = std::make_shared<AllocatedMemory>(
byte_size, preferred_memory_type, preferred_memory_type_id);
auto mutable_buffer = allocated_buffer->MutableBuffer(
allocated_memory_type, allocated_memory_type_id);
if ((mutable_buffer != nullptr) || (byte_size == 0)) {
if (byte_size != 0) {
*buffer = static_cast<void*>(mutable_buffer);
auto step = reinterpret_cast<Step*>(userp);
std::lock_guard<std::mutex> lk(step->output_mtx_);
if (*allocated_memory_type == TRITONSERVER_MEMORY_GPU) {
step->gpu_output_map_[*allocated_memory_type_id].emplace(
reinterpret_cast<uintptr_t>(*buffer), std::move(allocated_buffer));
} else {
step->cpu_output_map_.emplace(
reinterpret_cast<uintptr_t>(*buffer), std::move(allocated_buffer));
}
}
LOG_VERBOSE(1) << "Internal response allocation: " << tensor_name
<< ", size " << byte_size << ", addr " << *buffer
<< ", memory type " << *allocated_memory_type << ", type id "
<< *allocated_memory_type_id;
}
return nullptr; // Success
}
TRITONSERVER_Error*
EnsembleContext::ResponseRelease(
TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
LOG_VERBOSE(1) << "Internal response release: "
<< "size " << byte_size << ", addr " << buffer;
// Don't do anything when releasing a buffer since ResponseAlloc
// passes the ownership of the data to ensemble context.
return nullptr; // Success
}
TRITONSERVER_Error*
EnsembleContext::OutputBufferQuery(
TRITONSERVER_ResponseAllocator* allocator, void* userp,
const char* tensor_name, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
{
// Ensemble will always attempt to satisfy any output buffer request
return nullptr; // Success
}
void
EnsembleContext::RequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(request),
"deleting ensemble inference request");
auto request_tracker = reinterpret_cast<RequestTracker*>(userp);
if (request_tracker->DecrementCounter()) {
delete request_tracker;
}
}
}
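// Callback invoked for each response of an internal request. Reads any
// sequence parameter overrides, attaches the response outputs to the
// corresponding ensemble tensors, and then advances the ensemble via Proceed().
// The Step object is destroyed only once the final response flag has been seen.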
void
EnsembleContext::ResponseComplete(
TRITONSERVER_InferenceResponse* response, const uint32_t flags, void* userp)
{
auto step_ptr = std::unique_ptr<Step>(reinterpret_cast<Step*>(userp));
step_ptr->response_flags_ = flags;
if (response != nullptr) {
auto err = TRITONSERVER_InferenceResponseError(response);
uint32_t count;
bool parameter_override = false;
InferenceRequest::SequenceId correlation_id{0};
uint32_t flags = 0;
if (err == nullptr) {
err = TRITONSERVER_InferenceResponseParameterCount(response, &count);
if (err == nullptr) {
for (uint32_t idx = 0; idx < count; idx++) {
const char* name;
TRITONSERVER_ParameterType type;
const void* vvalue;
err = TRITONSERVER_InferenceResponseParameter(
response, idx, &name, &type, &vvalue);
if (err == nullptr) {
if (!strcmp(name, "sequence_id")) {
switch (type) {
case TRITONSERVER_PARAMETER_INT:
correlation_id = InferenceRequest::SequenceId(
*reinterpret_cast<const uint64_t*>(vvalue));
parameter_override = true;
break;
case TRITONSERVER_PARAMETER_STRING:
correlation_id = InferenceRequest::SequenceId(std::string(
*reinterpret_cast<const char* const*>(vvalue)));
parameter_override = true;
break;
default:
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"expected parameter 'sequence_id' to be "
"TRITONSERVER_PARAMETER_INT or "
"TRITONSERVER_PARAMETER_STRING");
}
} else if (!strcmp(name, "sequence_start")) {
if (type != TRITONSERVER_PARAMETER_BOOL) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"expect paremeter 'sequence_start' to be "
"TRITONSERVER_PARAMETER_BOOL");
} else {
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_START;
}
parameter_override = true;
}
} else if (!strcmp(name, "sequence_end")) {
if (type != TRITONSERVER_PARAMETER_BOOL) {
err = TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INVALID_ARG,
"expect paremeter 'sequence_end' to be "
"TRITONSERVER_PARAMETER_BOOL");
} else {
if (*reinterpret_cast<const bool*>(vvalue)) {
flags |= TRITONSERVER_REQUEST_FLAG_SEQUENCE_END;
}
parameter_override = true;
}
}
}
}
}
}
if (err == nullptr) {
err = TRITONSERVER_InferenceResponseOutputCount(response, &count);
if (err == nullptr) {
std::lock_guard<std::mutex> lock(step_ptr->ctx_->mutex_);
auto& output_to_tensor =
step_ptr->ctx_->info_->steps_[step_ptr->step_idx_]
.output_to_tensor_;
for (uint32_t idx = 0; idx < count; idx++) {
const char* name;
TRITONSERVER_DataType datatype;
const int64_t* shape;
uint64_t dim_count;
const void* base;
size_t byte_size;
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
void* userp;
err = TRITONSERVER_InferenceResponseOutput(
response, idx, &name, &datatype, &shape, &dim_count, &base,
&byte_size, &memory_type, &memory_type_id, &userp);
if (err == nullptr) {
auto it = output_to_tensor.find(name);
if (it != output_to_tensor.end()) {
std::unique_ptr<InferenceRequest::Input> tensor(
new InferenceRequest::Input(
it->second, TritonToDataType(datatype), shape,
dim_count));
if (byte_size != 0) {
std::lock_guard<std::mutex> output_lk(step_ptr->output_mtx_);
if (memory_type == TRITONSERVER_MEMORY_GPU) {
auto& gpu_output_map =
step_ptr->gpu_output_map_[memory_type_id];
auto it =
gpu_output_map.find(reinterpret_cast<uintptr_t>(base));
tensor->SetData(std::move(it->second));
gpu_output_map.erase(it);
} else {
auto it = step_ptr->cpu_output_map_.find(
reinterpret_cast<uintptr_t>(base));
tensor->SetData(std::move(it->second));
step_ptr->cpu_output_map_.erase(it);
}
}
auto& tensor_data = step_ptr->ctx_->tensor_data_[it->second];
if (parameter_override) {
step_ptr->updated_tensors_.emplace(
it->second, tensor_data.AddTensor(
std::move(tensor), correlation_id, flags));
} else {
step_ptr->updated_tensors_.emplace(
it->second,
tensor_data.AddTensor(
std::move(tensor), step_ptr->correlation_id_,
step_ptr->flags_));
}
} else {
LOG_VERBOSE(1)
<< "in ensemble, an internal response header specified "
"output '"
<< name << "' that does not map to any ensemble tensors";
}
}
if (err != nullptr) {
break;
}
}
}
}
if (err != nullptr) {
step_ptr->infer_status_ = err;
}
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(response),
"deleting inference response");
}
EnsembleContext::Proceed(step_ptr->ctx_, step_ptr);
// Expecting more responses
if ((flags & TRITONSERVER_RESPONSE_COMPLETE_FINAL) == 0) {
step_ptr.release();
}
}
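// Transition the ensemble state using the result of 'completed_step' and
// schedule any steps that became ready; errors are recorded in the context's
// ensemble status inside PrepareSteps().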
void
EnsembleContext::Proceed(
const std::shared_ptr<EnsembleContext>& context,
const std::unique_ptr<Step>& completed_step)
{
StepList ready_steps;
Status status = context->PrepareSteps(completed_step, &ready_steps);
if (status.IsOk()) {
ScheduleSteps(context, std::move(ready_steps));
}
}
Status
EnsembleContext::PrepareSteps(
const std::unique_ptr<Step>& completed_step, StepList* ready_steps)
{
{
std::lock_guard<std::mutex> lock(mutex_);
// On an initialization error, the ensemble status will not be OK from the beginning
if (completed_step == nullptr && !ensemble_status_.IsOk()) {
ensemble_status_ = FinishEnsemble();
}
if (ensemble_status_.IsOk()) {
StepList res;
std::set<std::pair<std::string, IterationCount>> updated_tensors;
ensemble_status_ = UpdateEnsembleState(completed_step, &updated_tensors);
if (ensemble_status_.IsOk()) {
ensemble_status_ = GetNextSteps(updated_tensors, ready_steps);
}
// Check and send ensemble response
if ((!ensemble_status_.IsOk()) || (inflight_step_counter_ == 0) ||
info_->is_decoupled_) {
std::unique_ptr<InferenceResponse> response;
if (ensemble_status_.IsOk()) {
ensemble_status_ =
CheckAndSetEnsembleOutput(updated_tensors, &response);
}
ensemble_status_ = FinishEnsemble(std::move(response));
}
}
return ensemble_status_;
}
}
Status
EnsembleContext::UpdateEnsembleState(
const std::unique_ptr<Step>& completed_step,
std::set<std::pair<std::string, IterationCount>>* updated_tensors)
{
updated_tensors->clear();
if (completed_step == nullptr) {
for (const auto& tensor_data : tensor_data_) {
if (!tensor_data.second.tensor_.empty()) {
updated_tensors->emplace(tensor_data.first, 0);
}
}
} else {
if (completed_step->response_flags_ &
TRITONSERVER_RESPONSE_COMPLETE_FINAL) {
inflight_step_counter_--;
}
RETURN_IF_TRITONSERVER_ERROR(completed_step->infer_status_);
updated_tensors->swap(completed_step->updated_tensors_);
}
return Status::Success;
}
Status
EnsembleContext::GetNextSteps(
const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
StepList* steps)
{
steps->clear();
std::set<std::pair<size_t, IterationCount>> next_step_idx;
// Get the steps whose input tensors have been set
for (const auto updated_tensor : updated_tensors) {
const auto& step_idx = (*tensor_to_step_)[updated_tensor.first];
for (const auto& idx : step_idx) {
bool ready = true;
for (const auto& input_pair : info_->steps_[idx].input_to_tensor_) {
auto& tensor = tensor_data_[input_pair.second].tensor_;
if (tensor.empty()) {
ready = false;
break;
} else {
// Check if other inputs have tensor with corresponding iteration
// count
if (tensor.find(updated_tensor.second) == tensor.end()) {
ready = false;
break;
}
}
}
if (ready) {
next_step_idx.emplace(idx, updated_tensor.second);
}
}
}
for (const auto& idx : next_step_idx) {
steps->emplace_back();
RETURN_IF_ERROR(InitStep(idx.first, idx.second, &(steps->back())));
}
inflight_step_counter_ += steps->size();
return Status::Success;
}
Status
EnsembleContext::InitStep(
const size_t step_idx, const IterationCount iteration_count,
std::unique_ptr<Step>* step)
{
const auto& istep = info_->steps_[step_idx];
auto& version_map = handles_[istep.model_name_];
auto& model = version_map[istep.model_version_];
const bool allow_batching = (model->Config().max_batch_size() > 0);
auto irequest = std::unique_ptr<InferenceRequest>(
new InferenceRequest(model, istep.model_version_));
// Store the pointers to tensors used so that we can prune them afterward.
// Can't prune the tensor in the input loop below as it may be used by
// multiple inputs in the same step.
std::map<TensorData*, size_t*> releasing_tensors;
// Set inputs in request, prepare input map,
// and set overridden parameter if any.
auto correlation_id = correlation_id_;
auto flags = flags_;
bool parameter_set = false;
for (const auto& pair : istep.input_to_tensor_) {
auto& tensor_data = tensor_data_[pair.second];
auto& tensor = tensor_data.tensor_[iteration_count];
// nullptr if and only if the tensor is an optional ensemble input that is
// not provided in the ensemble request. In that case, we don't add the
// input and expect that the ensemble pipeline is configured correctly
// (the input to the inner model is also optional)
if (tensor.data_ != nullptr) {
// If the actual shape and config shape agree with each other without
// considering batch size, non-batch / batch conversion is not required.
const inference::ModelInput* input_config;
model->GetInput(pair.first, &input_config);
auto shape = ReshapeTensorDims(
input_config->dims(), allow_batching, tensor_data.batch_size_,
tensor.data_->OriginalShape());
InferenceRequest::Input* input;
RETURN_IF_ERROR(irequest->AddOriginalInput(
pair.first, tensor.data_->DType(), shape, &input));
RETURN_IF_ERROR(input->SetData(tensor.data_->Data()));
for (const auto& host_policy_data : tensor.data_->HostPolicyData()) {
RETURN_IF_ERROR(
input->SetData(host_policy_data.first, host_policy_data.second));
}
}
releasing_tensors.emplace(&tensor_data, &tensor.remaining_reference_count_);
if (tensor.parameter_override_) {
if (parameter_set && ((correlation_id != tensor.correlation_id_) ||
(flags != tensor.flags_))) {
LOG_ERROR << irequest->LogRequest()
<< "Different set of response parameters are set for '"
<< istep.model_name_ << "'. Parameter correlation ID "
<< correlation_id << ", flags " << flags << " is used.";
continue;
}
correlation_id = tensor.correlation_id_;
flags = tensor.flags_;
parameter_set = true;
}
}
// Prune the tensor if it is not needed by other steps
for (auto& releasing_pair : releasing_tensors) {
if ((--(*releasing_pair.second)) == 0) {
releasing_pair.first->tensor_.erase(iteration_count);
}
}
// Set requested outputs in request header
for (const auto& pair : istep.output_to_tensor_) {
irequest->AddOriginalRequestedOutput(pair.first);
}
step->reset(new Step(step_idx, correlation_id, flags));
irequest->SetId(request_id_);
irequest->SetCorrelationId(correlation_id);
irequest->SetFlags(flags);
irequest->SetPriority(priority_);
irequest->SetTimeoutMicroseconds(timeout_);
#ifdef TRITON_ENABLE_STATS
irequest->SetSecondaryStatsAggregator(
&request_tracker_->ContextStatsAggregator());
#endif
irequest->SetResponseCallback(
reinterpret_cast<ResponseAllocator*>(allocator_.get()), step->get(),
ResponseComplete, step->get());
irequest->SetReleaseCallback(RequestComplete, request_tracker_);
RETURN_IF_ERROR(irequest->PrepareForInference());
#ifdef TRITON_ENABLE_TRACING
auto& parent_trace = request_tracker_->Request()->Trace();
if (parent_trace != nullptr) {
irequest->SetTrace(parent_trace->SpawnChildTrace());
irequest->Trace()->SetModelName(irequest->ModelName());
irequest->Trace()->SetModelVersion(irequest->ActualModelVersion());
}
#endif
// Record the batch size of output in advance as
// there is no other way to access it later on.
for (const auto& pair : istep.output_to_tensor_) {
auto& output_data_ = tensor_data_[pair.second];
output_data_.batch_size_ = irequest->BatchSize();
}
(*step)->request_ = std::move(irequest);
return Status::Success;
}
std::vector<int64_t>
EnsembleContext::ReshapeTensorDims(
const triton::common::DimsList& config_dims,
const bool config_allow_batching, const size_t tensor_batch_size,
const std::vector<int64_t>& dims)
{
bool reshaped = false;
std::vector<int64_t> res;
// Only attempt to reshape if one setting is batchable while the other is not;
// the case of two mismatched batchable shapes is not considered.
// If the actual shape and config shape agree with each other without
// considering batch size, non-batch / batch conversion is not required.
if (config_allow_batching != (tensor_batch_size != 0)) {
// The config expects batching but the tensor is generated from a non-batching model
if (config_allow_batching) {
if (triton::common::CompareDimsWithWildcard(config_dims, dims)) {
// If 'dims' already matches 'config_dims', prepend with batch size 1
res.push_back(1);
res.insert(res.end(), dims.begin(), dims.end());
reshaped = true;
}
// Otherwise, assume the tensor is already in the batched form expected
// by the model and do nothing
} else {
// Check if the batched tensor can be sent to the non-batching
// model as one tensor. If not, strip the batch dimension if
// it is batch size 1
if (!triton::common::CompareDimsWithWildcard(config_dims, dims) &&
(tensor_batch_size == 1)) {
res.assign(dims.begin() + 1, dims.end());
reshaped = true;
}
}
}
if (!reshaped) {
res = dims;
}
return res;
}
Status
EnsembleContext::FinishEnsemble(std::unique_ptr<InferenceResponse>&& response)
{
// Do nothing if the ensemble is finished
if (request_tracker_ == nullptr) {
return ensemble_status_;
}
// Add the ensemble name to make the error message easier to trace
if (!ensemble_status_.IsOk()) {
ensemble_status_ = Status(
ensemble_status_.StatusCode(), "in ensemble '" + info_->ensemble_name_ +
"', " + ensemble_status_.Message());
}
if (ensemble_status_.IsOk()) {
if (info_->is_decoupled_) {
if (response != nullptr) {
InferenceResponse::Send(std::move(response), 0 /* flags */);
}
if (inflight_step_counter_ != 0) {
return ensemble_status_;
}
request_tracker_->Request()->ResponseFactory()->SendFlags(
TRITONSERVER_RESPONSE_COMPLETE_FINAL);
} else {
InferenceResponse::Send(
std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL);
}
} else {
if (response != nullptr) {
InferenceResponse::SendWithStatus(
std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL,
ensemble_status_);
} else {
InferenceRequest::RespondIfError(
request_tracker_->Request(), ensemble_status_);
}
}
// We reach this point when the ensemble execution has come to an end;
// 'ensemble_status_' at this point is the final, representative status.
request_tracker_->SetStatus(ensemble_status_);
if (request_tracker_->DecrementCounter()) {
delete request_tracker_;
}
request_tracker_ = nullptr;
return ensemble_status_;
}
Status
EnsembleContext::CheckAndSetEnsembleOutput(
const std::set<std::pair<std::string, IterationCount>>& updated_tensors,
std::unique_ptr<InferenceResponse>* response)
{
IterationCount iteration_count = 0;
// Check if an updated tensor is one of the ensemble outputs and if all outputs
// have a tensor with the same iteration count
bool ready = false;
auto& lrequest = request_tracker_->Request();
const auto& requested_outputs = lrequest->ImmutableRequestedOutputs();
for (const auto updated_tensor : updated_tensors) {
if (requested_outputs.find(updated_tensor.first) ==
requested_outputs.end()) {
continue;
}
ready = true;
iteration_count = updated_tensor.second;
for (const auto& output : requested_outputs) {
auto& tensor = tensor_data_[output].tensor_;
if (tensor.empty()) {
ready = false;
break;
} else {
// Check if other outputs have tensor with corresponding iteration count
if (tensor.find(iteration_count) == tensor.end()) {
ready = false;
break;
}
}
}
}
if (!ready) {
if (info_->is_decoupled_) {
return Status::Success;
}
return Status(
Status::Code::INVALID_ARG,
lrequest->LogRequest() +
"unexpected deadlock, at least one output is not set while no more "
"ensemble steps can be made");
}
RETURN_IF_ERROR(lrequest->ResponseFactory()->CreateResponse(response));
bool cuda_async_copy = false;
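// Track the tensors consumed here so that their reference counts can be
// decremented and the tensors pruned once all outputs have been copied.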
std::map<TensorData*, size_t*> releasing_tensors;
for (const auto& output_pair : info_->ensemble_output_shape_) {
if (requested_outputs.find(output_pair.first) == requested_outputs.end()) {
continue;
}
// Check if output is ready
auto& tensor_data = tensor_data_[output_pair.first];
auto& tensor = tensor_data.tensor_[iteration_count];
auto shape = ReshapeTensorDims(
output_pair.second, (lrequest->BatchSize() != 0),
tensor_data.batch_size_, tensor.data_->OriginalShape());
InferenceResponse::Output* output;
RETURN_IF_ERROR((*response)->AddOutput(
output_pair.first, tensor.data_->DType(), shape, &output));
// Use the memory type of the memory block as preferred memory type
TRITONSERVER_MemoryType dst_memory_type;
int64_t dst_memory_type_id;
size_t content_size;
tensor.data_->Data()->BufferAt(
0, &content_size, &dst_memory_type, &dst_memory_type_id);
void* buffer;
RETURN_IF_ERROR(output->AllocateDataBuffer(
&buffer, content_size, &dst_memory_type, &dst_memory_type_id));
// Done with this output if 'content_size' is 0
if (content_size == 0) {
continue;
} else if (buffer == nullptr) {
return Status(
Status::Code::INTERNAL,
"failed to allocate buffer for output '" + output_pair.first + "'");
}
size_t content_offset = 0;
size_t content_idx = 0;
TRITONSERVER_MemoryType src_memory_type;
int64_t src_memory_type_id;
const char* content = tensor.data_->Data()->BufferAt(
content_idx, &content_size, &src_memory_type, &src_memory_type_id);
bool cuda_used = false;
while (content != nullptr) {
RETURN_IF_ERROR(CopyBuffer(
output_pair.first, src_memory_type, src_memory_type_id,
dst_memory_type, dst_memory_type_id, content_size, content,
((char*)buffer) + content_offset, stream_, &cuda_used));
cuda_async_copy |= cuda_used;
content_offset += content_size;
content_idx++;
content = tensor.data_->Data()->BufferAt(
content_idx, &content_size, &src_memory_type, &src_memory_type_id);
}
releasing_tensors.emplace(&tensor_data, &tensor.remaining_reference_count_);
if (tensor.parameter_override_) {
switch (lrequest->CorrelationId().Type()) {
case InferenceRequest::SequenceId::DataType::STRING:
(*response)->AddParameter(
"sequence_id", tensor.correlation_id_.StringValue().c_str());
break;
case InferenceRequest::SequenceId::DataType::UINT64:
(*response)->AddParameter(
"sequence_id",
(int64_t)tensor.correlation_id_.UnsignedIntValue());
break;
default:
(*response)->AddParameter(
"sequence_id",
(int64_t)tensor.correlation_id_.UnsignedIntValue());
break;
}
(*response)->AddParameter(
"sequence_start",
(tensor.flags_ & TRITONSERVER_REQUEST_FLAG_SEQUENCE_START) != 0);
(*response)->AddParameter(
"sequence_end",
(tensor.flags_ & TRITONSERVER_REQUEST_FLAG_SEQUENCE_END) != 0);
}
}
if (cuda_async_copy) {
#ifdef TRITON_ENABLE_GPU
cudaStreamSynchronize(stream_);
#else
return Status(
Status::Code::INTERNAL,
"unexpected CUDA copy flag set while GPU is not supported");
#endif // TRITON_ENABLE_GPU
}
// Prune the tensor if it is not needed by other steps
for (auto& releasing_pair : releasing_tensors) {
if ((--(*releasing_pair.second)) == 0) {
releasing_pair.first->tensor_.erase(iteration_count);
}
}
return Status::Success;
}
void
EnsembleContext::ScheduleSteps(
const std::shared_ptr<EnsembleContext>& context, StepList&& steps)
{
for (auto& step : steps) {
step->ctx_ = context;
bool should_schedule = false;
// Must release lock before InferAsync to avoid deadlock, as the same thread
// will be calling request/response callbacks on cache hits, which will
// attempt to acquire the lock already held
{
std::lock_guard<std::mutex> lock(context->mutex_);
// Need to check the ensemble_status_ to ensure the FinishEnsemble()
// is called only once.
if (context->ensemble_status_.IsOk()) {
context->request_tracker_->IncrementCounter();
should_schedule = true;
}
}
if (should_schedule) {
// On a successful call to InferAsync(), the step will be released by
// the response callback. When the response callback is invoked, the
// step must not own (and release) the request as the request should be
// transferred to and managed by Triton core. In the case of a cache hit, the
// request hasn't been transferred, which could cause a double free, so the
// request ownership is moved out of the step here to avoid that
std::unique_ptr<InferenceRequest> request = std::move(step->request_);
auto step_status = context->is_->InferAsync(request);
if (!step_status.IsOk()) {
std::lock_guard<std::mutex> lock(context->mutex_);
context->ensemble_status_ = step_status;
// The request was not sent to the server properly, so its
// release function should not be expected to be called.
context->request_tracker_->DecrementCounter();
context->ensemble_status_ = context->FinishEnsemble();
break;
}
}
step.release();
}
}
} // namespace
Status
EnsembleScheduler::Create(
InferenceStatsAggregator* const stats_aggregator,
InferenceServer* const server, const inference::ModelConfig& config,
std::unique_ptr<Scheduler>* scheduler)
{
scheduler->reset(new EnsembleScheduler(stats_aggregator, server, config));
return Status::Success;
}
Status
EnsembleScheduler::Enqueue(std::unique_ptr<InferenceRequest>& request)
{
// Queue timer starts at the beginning of the queueing and
// scheduling process
request->CaptureQueueStartNs();
INFER_TRACE_ACTIVITY(
request->Trace(), TRITONSERVER_TRACE_QUEUE_START,
request->QueueStartNs());
#ifdef TRITON_ENABLE_TRACING
request->TraceInputTensors(
TRITONSERVER_TRACE_TENSOR_QUEUE_INPUT, "EnsembleScheduler Enqueue");
#endif // TRITON_ENABLE_TRACING
// Add additional callback to keep track of in-flight count
++inflight_count_;
request->AddInternalReleaseCallback([this]() { --inflight_count_; });
std::shared_ptr<EnsembleContext> context(new EnsembleContext(
metric_reporter_.get(), stats_aggregator_, is_, info_.get(), request,
stream_));
EnsembleContext::Proceed(context);
return Status::Success;
}
EnsembleScheduler::EnsembleScheduler(
InferenceStatsAggregator* const stats_aggregator,
InferenceServer* const server, const inference::ModelConfig& config)
: stats_aggregator_(stats_aggregator), is_(server), stream_(nullptr),
inflight_count_(0)
{
#ifdef TRITON_ENABLE_GPU
// create CUDA stream
auto cuerr = cudaStreamCreate(&stream_);
if (cuerr != cudaSuccess) {
stream_ = nullptr;
LOG_ERROR << "unable to create stream for " << config.name() << ": "
<< cudaGetErrorString(cuerr);
}
#endif // TRITON_ENABLE_GPU
#ifdef TRITON_ENABLE_METRICS
if (Metrics::Enabled()) {
MetricModelReporter::Create(
config.name(), 1, METRIC_REPORTER_ID_CPU, config.metric_tags(),
&metric_reporter_);
}
#endif // TRITON_ENABLE_METRICS
// Set 'info_' based on 'config'
info_.reset(new EnsembleInfo());
info_->ensemble_name_ = config.name();
// This config field is filled internally for ensemble models
info_->is_decoupled_ = config.model_transaction_policy().decoupled();
for (const auto& input : config.input()) {
info_->tensor_to_step_.emplace(input.name(), std::set<size_t>());
if (input.optional()) {
info_->optional_inputs_.emplace(input.name());
}
}
for (const auto& output : config.output()) {
info_->tensor_to_step_.emplace(output.name(), std::set<size_t>());
if (output.has_reshape()) {
info_->ensemble_output_shape_[output.name()] = output.reshape().shape();
} else {
info_->ensemble_output_shape_[output.name()] = output.dims();
}
}
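// For every scheduling step, record which ensemble tensors feed its inputs
// (tensor_to_step_), which step produces each tensor (tensor_to_prev_step_),
// and the per-step input/output name mappings.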
for (const auto& element : config.ensemble_scheduling().step()) {
size_t step_idx = info_->steps_.size();
info_->steps_.emplace_back(element.model_name(), element.model_version());
for (const auto& pair : element.input_map()) {
auto it = info_->tensor_to_step_.find(pair.second);
if (it == info_->tensor_to_step_.end()) {
it = info_->tensor_to_step_.emplace(pair.second, std::set<size_t>())
.first;
}
it->second.insert(step_idx);
info_->steps_[step_idx].input_to_tensor_.emplace(
std::make_pair(pair.first, pair.second));
}
for (const auto& pair : element.output_map()) {
auto it = info_->tensor_to_step_.find(pair.second);
if (it == info_->tensor_to_step_.end()) {
it = info_->tensor_to_step_.emplace(pair.second, std::set<size_t>())
.first;
}
info_->steps_[step_idx].output_to_tensor_.emplace(
std::make_pair(pair.first, pair.second));
info_->tensor_to_prev_step_.emplace(pair.second, step_idx);
}
}
}
EnsembleScheduler::~EnsembleScheduler()
{
#ifdef TRITON_ENABLE_GPU
if (stream_ != nullptr) {
cudaError_t err = cudaStreamDestroy(stream_);
if (err != cudaSuccess) {
LOG_ERROR << "Failed to destroy cuda stream: " << cudaGetErrorString(err);
}
}
#endif // TRITON_ENABLE_GPU
}
}} // namespace triton::core
#endif // TRITON_ENABLE_ENSEMBLE
// Copyright 2019-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_ENSEMBLE
#include <memory>
#include "metric_model_reporter.h"
#include "model_config.pb.h"
#include "model_config_utils.h"
#include "scheduler.h"
#include "status.h"
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#endif // TRITON_ENABLE_GPU
namespace triton { namespace core {
#ifndef TRITON_ENABLE_GPU
using cudaStream_t = void*;
#endif // TRITON_ENABLE_GPU
class InferenceServer;
struct EnsembleInfo {
struct StepInfo {
StepInfo(const std::string& model_name, const int64_t model_version)
: model_name_(model_name), model_version_(model_version)
{
}
std::string model_name_;
int64_t model_version_;
std::unordered_map<std::string, std::string> input_to_tensor_;
std::unordered_map<std::string, std::string> output_to_tensor_;
};
std::string ensemble_name_;
bool is_decoupled_;
// the ensemble output (re)shape expected by the ensemble
std::unordered_map<std::string, triton::common::DimsList>
ensemble_output_shape_;
// Inputs that are marked optional for the ensemble
std::set<std::string> optional_inputs_;
std::vector<StepInfo> steps_;
// Only include a step if the ensemble tensor is used as input in that step
std::unordered_map<std::string, std::set<size_t>> tensor_to_step_;
// backward path, ensemble tensor to the step that provides its data
std::unordered_map<std::string, size_t> tensor_to_prev_step_;
};
// Scheduler that implements ensemble scheduling.
class EnsembleScheduler : public Scheduler {
public:
// Create a scheduler to process ensemble requests and
// to dispatch requests to models in ensemble internally.
static Status Create(
InferenceStatsAggregator* const stats_aggregator,
InferenceServer* const server, const inference::ModelConfig& config,
std::unique_ptr<Scheduler>* scheduler);
~EnsembleScheduler();
// \see Scheduler::Enqueue()
Status Enqueue(std::unique_ptr<InferenceRequest>& request) override;
// \see Scheduler::InflightInferenceCount()
size_t InflightInferenceCount() override { return inflight_count_; }
// \see Scheduler::Stop()
void Stop() override {}
private:
EnsembleScheduler(
InferenceStatsAggregator* const stats_aggregator,
InferenceServer* const server, const inference::ModelConfig& config);
std::shared_ptr<MetricModelReporter> metric_reporter_;
InferenceStatsAggregator* const stats_aggregator_;
InferenceServer* const is_;
// Ensemble information that is built from model config
std::unique_ptr<EnsembleInfo> info_;
// The stream used for data transfer.
cudaStream_t stream_;
std::atomic<size_t> inflight_count_;
};
}} // namespace triton::core
#endif // TRITON_ENABLE_ENSEMBLE
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifdef TRITON_ENABLE_ENSEMBLE
#include "ensemble_utils.h"
#include <set>
#include "constants.h"
#include "model.h"
#include "model_config_utils.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
namespace {
/// A basic unit in the ensemble graph that records the data type and shape
/// of an ensemble tensor and which model it is inferred from.
struct TensorNode {
TensorNode(
const std::string& model_name, const bool batching,
const inference::DataType& type, const triton::common::DimsList& dims)
: model_name_(model_name), type_(type), dims_(dims), is_decoupled_(false),
decouple_label_(0), visited_(false)
{
// Expand dims to the full shape, which includes the batch dimension if it exists
if (batching) {
full_dims_.Add(-1);
}
full_dims_.MergeFrom(dims_);
}
// Constructor for symbolic nodes
TensorNode(const std::string& model_name)
: model_name_(model_name), is_decoupled_(false), decouple_label_(0),
visited_(false)
{
}
std::string model_name_;
inference::DataType type_;
triton::common::DimsList dims_;
triton::common::DimsList full_dims_;
bool is_decoupled_;
size_t decouple_label_;
bool visited_;
std::vector<TensorNode*> prev_nodes_;
std::vector<TensorNode*> next_nodes_;
// A symbolic node to keep track of the decouple label of nodes that
// are outputs of the same step.
std::shared_ptr<TensorNode> sibling_node_;
};
/// Validate if the data type and the shape of two TensorNode object are
/// consistent.
/// \param lhs One of the TensorNode object to be validated.
/// \param rhs Another TensorNode object to be validated.
/// \param message Extra message prepended to the error message
/// if the returned status is non-OK.
/// \return The error status. A non-OK status indicates the TensorNode objects
/// are not consistent.
Status
ValidateTensorConsistency(
const TensorNode& lhs, const TensorNode& rhs, const std::string& message)
{
if (lhs.type_ != rhs.type_) {
return Status(
Status::Code::INVALID_ARG,
message + "inconsistent data type: " +
inference::DataType_Name(lhs.type_) + " is inferred from model " +
lhs.model_name_ + " while " + inference::DataType_Name(rhs.type_) +
" is inferred from model " + rhs.model_name_);
}
  // Shapes must match, or either one may use a variable-size dimension; if a
  // variable-size dimension is used, shape consistency will be checked at
  // runtime. If the dims mismatch, compare again with the full dims in case
  // the tensor is used for both a non-batching model and a batching model. In
  // that case, it is acceptable if the non-batching model shape is
  // [-1, d_0, d_1, ..., d_n] while the batching model shape is
  // [d_0, d_1, ..., d_n].
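  // For example, a non-batching model tensor with dims [-1, 3, 224, 224] and
  // a batching model tensor with dims [3, 224, 224] have mismatching 'dims_',
  // but both expand to the same full dims [-1, 3, 224, 224], so they are
  // accepted as consistent.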
if (!triton::common::CompareDimsWithWildcard(lhs.dims_, rhs.dims_) &&
!triton::common::CompareDimsWithWildcard(
lhs.full_dims_, rhs.full_dims_)) {
return Status(
Status::Code::INVALID_ARG,
message + "inconsistent shape: " +
triton::common::DimsListToString(lhs.full_dims_) +
" is inferred from model " + lhs.model_name_ + " while " +
triton::common::DimsListToString(rhs.full_dims_) +
" is inferred from model " + rhs.model_name_);
}
return Status::Success;
}
Status
ValidateTensorMapping(
const std::string& ensemble, const inference::ModelEnsembling::Step& step,
const inference::ModelConfig& model_config,
std::unordered_map<std::string, TensorNode>* ensemble_tensors)
{
const bool batching = (model_config.max_batch_size() > 0);
// Check all inputs are mapped and no mapping to invalid inputs
std::set<std::string> input_names;
for (const auto& model_input : model_config.input()) {
input_names.insert(model_input.name());
}
for (const auto& input_map : step.input_map()) {
if (input_names.find(input_map.first) == input_names.end()) {
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble + ", ensemble tensor " + input_map.second +
" is mapping to non-existing input " + input_map.first +
" in model " + step.model_name());
}
}
for (const auto& model_input : model_config.input()) {
size_t mapped_cnt = 0;
for (const auto& input_map : step.input_map()) {
if (model_input.name() == input_map.first) {
TensorNode model_tensor(
step.model_name(), batching, model_input.data_type(),
model_input.dims());
auto it = ensemble_tensors->find(input_map.second);
if (it != ensemble_tensors->end()) {
RETURN_IF_ERROR(ValidateTensorConsistency(
it->second, model_tensor,
"in ensemble " + ensemble + ", ensemble tensor " +
input_map.second + ": "));
} else {
ensemble_tensors->emplace(
std::make_pair(input_map.second, model_tensor));
}
mapped_cnt++;
}
}
if (mapped_cnt == 0) {
// Allow the input to be excluded from ensemble if it is optional
if (model_input.optional()) {
continue;
}
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble + ", input " + model_input.name() +
" in model " + model_config.name() +
" is not mapped to any ensemble tensors");
} else if (mapped_cnt > 1) {
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble + ", input " + model_input.name() +
" in model " + model_config.name() +
" is mapped to multiple ensemble tensors");
}
}
// Check no multiple mappings to same ensemble tensor
// and no mapping from invalid outputs
std::set<std::string> output_names;
for (const auto& model_output : model_config.output()) {
output_names.insert(model_output.name());
}
for (const auto& output_map : step.output_map()) {
if (output_names.find(output_map.first) == output_names.end()) {
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble + ", ensemble tensor " + output_map.second +
" is mapped from non-existing output " + output_map.first +
" in model " + step.model_name());
}
}
std::shared_ptr<TensorNode> sibling_node(new TensorNode(step.model_name()));
for (const auto& output_map : step.output_map()) {
size_t mapped_cnt = 0;
for (const auto& model_output : model_config.output()) {
if (model_output.name() == output_map.first) {
TensorNode model_tensor(
step.model_name(), batching, model_output.data_type(),
model_output.dims());
auto it = ensemble_tensors->find(output_map.second);
if (it != ensemble_tensors->end()) {
RETURN_IF_ERROR(ValidateTensorConsistency(
it->second, model_tensor,
"in ensemble " + ensemble + ", ensemble tensor " +
output_map.second + ": "));
} else {
it = ensemble_tensors
->emplace(std::make_pair(output_map.second, model_tensor))
.first;
}
it->second.sibling_node_ = sibling_node;
mapped_cnt++;
}
}
if (mapped_cnt > 1) {
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble + ", multiple outputs in model " +
model_config.name() + " are mapped to the same ensemble tensor " +
output_map.second);
}
}
// link ensemble tensors
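  // Each ensemble tensor produced by this step depends on every ensemble
  // tensor consumed by this step, so producers and consumers are connected in
  // both directions; the resulting graph is traversed later to validate the
  // data flow and to propagate decouple labels.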
bool is_decoupled = model_config.model_transaction_policy().decoupled();
for (const auto& output_map : step.output_map()) {
auto& node = ensemble_tensors->find(output_map.second)->second;
node.is_decoupled_ = is_decoupled;
for (const auto& input_map : step.input_map()) {
auto& prev_node = ensemble_tensors->find(input_map.second)->second;
node.prev_nodes_.push_back(&prev_node);
prev_node.next_nodes_.push_back(&node);
}
}
return Status::Success;
}
} // namespace
Status
ValidateEnsembleConfig(
ModelRepositoryManager* model_repository_manager,
ModelRepositoryManager::DependencyNode* ensemble)
{
const auto& ensemble_config = ensemble->model_config_;
if (!ensemble_config.has_ensemble_scheduling()) {
return Status::Success;
}
const auto& ensemble_name = ensemble->model_name_;
const bool batching = (ensemble_config.max_batch_size() > 0);
std::unordered_map<std::string, TensorNode> ensemble_tensors;
for (const auto& input : ensemble_config.input()) {
const auto& dims =
input.has_reshape() ? input.reshape().shape() : input.dims();
TensorNode input_node(ensemble_name, batching, input.data_type(), dims);
ensemble_tensors.emplace(std::make_pair(input.name(), input_node));
}
TensorNode sink_node(ensemble_name);
for (const auto& output : ensemble_config.output()) {
const auto& dims =
output.has_reshape() ? output.reshape().shape() : output.dims();
TensorNode output_node(ensemble_name, batching, output.data_type(), dims);
auto it =
ensemble_tensors.emplace(std::make_pair(output.name(), output_node))
.first;
sink_node.prev_nodes_.emplace_back(&(it->second));
it->second.next_nodes_.emplace_back(&sink_node);
}
for (const auto& step : ensemble_config.ensemble_scheduling().step()) {
const auto& model_name = step.model_name();
inference::ModelConfig model_config;
for (auto& node : ensemble->upstreams_) {
if (model_name == node.first->model_name_) {
// Obtain completed config from model instance
std::shared_ptr<Model> model;
RETURN_IF_ERROR(
model_repository_manager->GetModel(model_name, -1, &model));
model_config = model->Config();
break;
}
}
    // A batchable ensemble can include non-batchable models as long as the
    // expanded shapes are consistent.
if ((model_config.max_batch_size() != 0) &&
(model_config.max_batch_size() < ensemble_config.max_batch_size())) {
return Status(
Status::Code::INVALID_ARG,
"ensemble " + ensemble_name + " allows maximum batch size " +
std::to_string(ensemble_config.max_batch_size()) +
", but it contains model " + model_name +
" which only allows maximum batch size to be " +
std::to_string(model_config.max_batch_size()));
}
RETURN_IF_ERROR(ValidateTensorMapping(
ensemble_name, step, model_config, &ensemble_tensors));
}
  // Visit the nodes to check the data flow and validate the decoupled
  // workflow, if any.
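  // The traversal below is a breadth-first walk over the tensor graph starting
  // from the ensemble inputs. Each node is assigned a decouple label: a node
  // produced by a decoupled model starts a new label, other nodes inherit the
  // label of their producers, and outputs of the same step (siblings) share
  // one label. A step whose inputs carry different labels is rejected, and the
  // ensemble is marked decoupled if any non-zero label was assigned.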
size_t decouple_label = 0;
std::deque<TensorNode*> current_iterators;
for (const auto& input : ensemble_config.input()) {
auto it = ensemble_tensors.find(input.name());
it->second.visited_ = true;
current_iterators.push_back(&(it->second));
}
while (!current_iterators.empty()) {
auto& current_node = current_iterators.front();
for (auto& next_node : current_node->next_nodes_) {
if (next_node->visited_) {
continue;
}
bool next_node_ready = true;
for (auto& prev_node : next_node->prev_nodes_) {
if (!prev_node->visited_) {
next_node_ready = false;
break;
}
}
if (next_node_ready) {
size_t prev_decouple_label = next_node->prev_nodes_[0]->decouple_label_;
for (auto& prev_node : next_node->prev_nodes_) {
if (prev_node->decouple_label_ != prev_decouple_label) {
return Status(
Status::Code::INVALID_ARG,
"in ensemble " + ensemble_name + ", step of model '" +
next_node->model_name_ +
"' receives inputs originated from different decoupled "
"models");
}
}
if (next_node->sibling_node_ != nullptr) {
if (next_node->sibling_node_->visited_) {
next_node->decouple_label_ =
next_node->sibling_node_->decouple_label_;
} else {
next_node->decouple_label_ = next_node->is_decoupled_
? ++decouple_label
: prev_decouple_label;
next_node->sibling_node_->decouple_label_ =
next_node->decouple_label_;
next_node->sibling_node_->visited_ = true;
}
} else {
next_node->decouple_label_ =
next_node->is_decoupled_ ? ++decouple_label : prev_decouple_label;
}
next_node->visited_ = true;
current_iterators.push_back(next_node);
}
}
current_iterators.pop_front();
}
ensemble->model_config_.mutable_model_transaction_policy()->set_decoupled(
decouple_label != 0);
return Status::Success;
}
}} // namespace triton::core
#endif // TRITON_ENABLE_ENSEMBLE
// Copyright 2018-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_ENSEMBLE
#include <deque>
#include <unordered_map>
#include "model_config.pb.h"
#include "model_repository_manager.h"
#include "status.h"
#include "triton/common/model_config.h"
namespace triton { namespace core {
/// Validate that the ensemble is specified correctly, assuming that the
/// inputs and outputs specified in the depending model configurations are
/// accurate.
/// \param model_repository_manager The model repository manager used to
/// acquire model configs.
/// \param ensemble The ensemble to be validated.
/// \return The error status.
Status ValidateEnsembleConfig(
ModelRepositoryManager* model_repository_manager,
ModelRepositoryManager::DependencyNode* ensemble);
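//
// Illustrative usage sketch (hypothetical caller; the variables below are
// assumptions, not part of this header):
//
//   ModelRepositoryManager* manager = ...;               // owning manager
//   ModelRepositoryManager::DependencyNode* node = ...;  // ensemble node
//   Status status = ValidateEnsembleConfig(manager, node);
//   if (!status.IsOk()) {
//     LOG_ERROR << "invalid ensemble: " << status.Message();
//   }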
}} // namespace triton::core
#endif // TRITON_ENABLE_ENSEMBLE
// Copyright 2019-2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "filesystem.h"
#ifdef _WIN32
// suppress the min and max definitions in Windef.h.
#define NOMINMAX
#include <Windows.h>
// Define _CRT_INTERNAL_NONSTDC_NAMES 1 before including the Microsoft-provided
// C runtime library to expose declarations without the "_" prefix, matching
// POSIX style.
#define _CRT_INTERNAL_NONSTDC_NAMES 1
#include <direct.h>
#include <io.h>
#else
#include <dirent.h>
#include <unistd.h>
#endif
#ifdef TRITON_ENABLE_GCS
#include <google/cloud/storage/client.h>
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
#include <aws/core/Aws.h>
#include <aws/core/auth/AWSCredentialsProvider.h>
#include <aws/s3/S3Client.h>
#include <aws/s3/model/GetObjectRequest.h>
#include <aws/s3/model/HeadBucketRequest.h>
#include <aws/s3/model/HeadObjectRequest.h>
#include <aws/s3/model/ListObjectsRequest.h>
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
#include <blob/blob_client.h>
#include <storage_account.h>
#include <storage_credential.h>
#undef LOG_INFO
#undef LOG_WARNING
#endif // TRITON_ENABLE_AZURE_STORAGE
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/text_format.h>
#include <re2/re2.h>
#include <stdio.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <cerrno>
#include <fstream>
#include <mutex>
#include "constants.h"
#include "status.h"
#include "triton/common/logging.h"
#define TRITONJSON_STATUSTYPE triton::core::Status
#define TRITONJSON_STATUSRETURN(M) \
return triton::core::Status(triton::core::Status::Code::INTERNAL, (M))
#define TRITONJSON_STATUSSUCCESS triton::core::Status::Success
#include "triton/common/triton_json.h"
#ifdef _WIN32
// <sys/stat.h> in Windows doesn't define S_ISDIR macro
#if !defined(S_ISDIR) && defined(S_IFMT) && defined(S_IFDIR)
#define S_ISDIR(m) (((m)&S_IFMT) == S_IFDIR)
#endif
#define F_OK 0
#endif
namespace triton { namespace core {
namespace {
// Check if a local path is a directory. This is needed by both LocalFileSystem
// and LocalizedPath, so it is factored into this common helper.
Status
IsPathDirectory(const std::string& path, bool* is_dir)
{
*is_dir = false;
struct stat st;
if (stat(path.c_str(), &st) != 0) {
return Status(Status::Code::INTERNAL, "failed to stat file " + path);
}
*is_dir = S_ISDIR(st.st_mode);
return Status::Success;
}
} // namespace
LocalizedPath::~LocalizedPath()
{
if (!local_path_.empty()) {
bool is_dir = true;
IsDirectory(local_path_, &is_dir);
LOG_STATUS_ERROR(
DeletePath(is_dir ? local_path_ : DirName(local_path_)),
"failed to delete localized path");
}
}
namespace {
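// Abstract interface implemented by each storage backend below: the local
// file system and, when enabled at build time, GCS, S3, and Azure Storage.
// Paths are passed as strings in the backend's native form (e.g. "gs://",
// "s3://", or "as://" prefixes for the cloud backends) and every operation
// reports failure through the returned Status.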
class FileSystem {
public:
virtual Status FileExists(const std::string& path, bool* exists) = 0;
virtual Status IsDirectory(const std::string& path, bool* is_dir) = 0;
virtual Status FileModificationTime(
const std::string& path, int64_t* mtime_ns) = 0;
virtual Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents) = 0;
virtual Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs) = 0;
virtual Status GetDirectoryFiles(
const std::string& path, std::set<std::string>* files) = 0;
virtual Status ReadTextFile(
const std::string& path, std::string* contents) = 0;
virtual Status LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized) = 0;
virtual Status WriteTextFile(
const std::string& path, const std::string& contents) = 0;
virtual Status WriteBinaryFile(
const std::string& path, const char* contents,
const size_t content_len) = 0;
virtual Status MakeDirectory(
const std::string& dir, const bool recursive) = 0;
virtual Status MakeTemporaryDirectory(std::string* temp_dir) = 0;
virtual Status DeletePath(const std::string& path) = 0;
};
class LocalFileSystem : public FileSystem {
public:
Status FileExists(const std::string& path, bool* exists) override;
Status IsDirectory(const std::string& path, bool* is_dir) override;
Status FileModificationTime(
const std::string& path, int64_t* mtime_ns) override;
Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents) override;
Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs) override;
Status GetDirectoryFiles(
const std::string& path, std::set<std::string>* files) override;
Status ReadTextFile(const std::string& path, std::string* contents) override;
Status LocalizePath(
const std::string& path,
std::shared_ptr<LocalizedPath>* localized) override;
Status WriteTextFile(
const std::string& path, const std::string& contents) override;
Status WriteBinaryFile(
const std::string& path, const char* contents,
const size_t content_len) override;
Status MakeDirectory(const std::string& dir, const bool recursive) override;
Status MakeTemporaryDirectory(std::string* temp_dir) override;
Status DeletePath(const std::string& path) override;
};
Status
LocalFileSystem::FileExists(const std::string& path, bool* exists)
{
*exists = (access(path.c_str(), F_OK) == 0);
return Status::Success;
}
Status
LocalFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
return IsPathDirectory(path, is_dir);
}
Status
LocalFileSystem::FileModificationTime(
const std::string& path, int64_t* mtime_ns)
{
struct stat st;
if (stat(path.c_str(), &st) != 0) {
return Status(Status::Code::INTERNAL, "failed to stat file " + path);
}
#ifdef _WIN32
  // On Windows, st_mtime is a time_t (seconds resolution)
*mtime_ns = std::max(st.st_mtime, st.st_ctime);
#else
*mtime_ns =
std::max(TIMESPEC_TO_NANOS(st.st_mtim), TIMESPEC_TO_NANOS(st.st_ctim));
#endif
return Status::Success;
}
Status
LocalFileSystem::GetDirectoryContents(
const std::string& path, std::set<std::string>* contents)
{
#ifdef _WIN32
WIN32_FIND_DATA entry;
// Append "*" to obtain all files under 'path'
HANDLE dir = FindFirstFile(JoinPath({path, "*"}).c_str(), &entry);
if (dir == INVALID_HANDLE_VALUE) {
return Status(Status::Code::INTERNAL, "failed to open directory " + path);
}
if ((strcmp(entry.cFileName, ".") != 0) &&
(strcmp(entry.cFileName, "..") != 0)) {
contents->insert(entry.cFileName);
}
while (FindNextFile(dir, &entry)) {
if ((strcmp(entry.cFileName, ".") != 0) &&
(strcmp(entry.cFileName, "..") != 0)) {
contents->insert(entry.cFileName);
}
}
FindClose(dir);
#else
DIR* dir = opendir(path.c_str());
if (dir == nullptr) {
return Status(Status::Code::INTERNAL, "failed to open directory " + path);
}
struct dirent* entry;
while ((entry = readdir(dir)) != nullptr) {
std::string entryname = entry->d_name;
if ((entryname != ".") && (entryname != "..")) {
contents->insert(entryname);
}
}
closedir(dir);
#endif
return Status::Success;
}
Status
LocalFileSystem::GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs)
{
RETURN_IF_ERROR(GetDirectoryContents(path, subdirs));
// Erase non-directory entries...
for (auto iter = subdirs->begin(); iter != subdirs->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
if (!is_dir) {
iter = subdirs->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
LocalFileSystem::GetDirectoryFiles(
const std::string& path, std::set<std::string>* files)
{
RETURN_IF_ERROR(GetDirectoryContents(path, files));
// Erase directory entries...
for (auto iter = files->begin(); iter != files->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
if (is_dir) {
iter = files->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
LocalFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
std::ifstream in(path, std::ios::in | std::ios::binary);
if (!in) {
return Status(
Status::Code::INTERNAL,
"failed to open text file for read " + path + ": " + strerror(errno));
}
in.seekg(0, std::ios::end);
contents->resize(in.tellg());
in.seekg(0, std::ios::beg);
in.read(&(*contents)[0], contents->size());
in.close();
return Status::Success;
}
Status
LocalFileSystem::LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
// For local file system we don't actually need to download the
// directory or file. We use it in place.
localized->reset(new LocalizedPath(path));
return Status::Success;
}
Status
LocalFileSystem::WriteTextFile(
const std::string& path, const std::string& contents)
{
std::ofstream out(path, std::ios::out | std::ios::binary);
if (!out) {
return Status(
Status::Code::INTERNAL,
"failed to open text file for write " + path + ": " + strerror(errno));
}
out.write(&contents[0], contents.size());
out.close();
return Status::Success;
}
Status
LocalFileSystem::WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len)
{
std::ofstream out(path, std::ios::out | std::ios::binary);
if (!out) {
return Status(
Status::Code::INTERNAL, "failed to open binary file for write " + path +
": " + strerror(errno));
}
out.write(contents, content_len);
return Status::Success;
}
Status
LocalFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
#ifdef _WIN32
if (mkdir(dir.c_str()) == -1)
#else
if (mkdir(dir.c_str(), S_IRWXU) == -1)
#endif
{
    // Only tolerate the error caused by a missing parent directory
    // if 'recursive' creation is requested
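    // For example, creating "/models/a/b" recursively when only "/models"
    // exists first creates "/models/a" via the recursive call below and then
    // retries "/models/a/b".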
if ((errno == ENOENT) && (!dir.empty()) && recursive) {
RETURN_IF_ERROR(MakeDirectory(DirName(dir), recursive));
// Retry the creation
#ifdef _WIN32
if (mkdir(dir.c_str()) == -1)
#else
if (mkdir(dir.c_str(), S_IRWXU) == -1)
#endif
{
return Status(
Status::Code::INTERNAL, "Failed to create directory '" + dir +
"', errno:" + strerror(errno));
}
} else {
return Status(
Status::Code::INTERNAL,
"Failed to create directory '" + dir + "', errno:" + strerror(errno));
}
}
return Status::Success;
}
Status
LocalFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
#ifdef _WIN32
char temp_path[MAX_PATH + 1];
size_t temp_path_length = GetTempPath(MAX_PATH + 1, temp_path);
if (temp_path_length == 0) {
return Status(
Status::Code::INTERNAL,
"Failed to get local directory for temporary files");
}
  // There is no single operation like 'mkdtemp' on Windows, so generating a
  // unique temporary directory is a multi-step process: get a temporary file
  // name, delete the file (file creation is a side effect of getting the
  // name), and create the corresponding directory. A mutex is used to avoid a
  // possible race condition between these steps. However, it doesn't prevent
  // another process from creating a temporary file with the same name, so the
  // race condition may still happen. One possible solution is to reserve a
  // temporary directory for the process and generate temporary model
  // directories inside it.
static std::mutex mtx;
std::lock_guard<std::mutex> lk(mtx);
  // Construct a std::string because the filled 'temp_path' is not a C string,
  // and so that 'temp_path' can be reused to hold the temp file name.
std::string temp_path_str(temp_path, temp_path_length);
if (GetTempFileName(temp_path_str.c_str(), "folder", 0, temp_path) == 0) {
return Status(Status::Code::INTERNAL, "Failed to create local temp folder");
}
*temp_dir = temp_path;
DeleteFile(temp_dir->c_str());
if (CreateDirectory(temp_dir->c_str(), NULL) == 0) {
return Status(
Status::Code::INTERNAL,
"Failed to create local temp folder: " + *temp_dir);
}
#else
std::string folder_template = "/tmp/folderXXXXXX";
char* res = mkdtemp(const_cast<char*>(folder_template.c_str()));
if (res == nullptr) {
return Status(
Status::Code::INTERNAL,
"Failed to create local temp folder: " + folder_template +
", errno:" + strerror(errno));
}
*temp_dir = res;
#endif
return Status::Success;
}
Status
LocalFileSystem::DeletePath(const std::string& path)
{
bool is_dir = false;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (is_dir) {
std::set<std::string> contents;
RETURN_IF_ERROR(GetDirectoryContents(path, &contents));
for (const auto& content : contents) {
RETURN_IF_ERROR(DeletePath(JoinPath({path, content})));
}
rmdir(path.c_str());
} else {
remove(path.c_str());
}
return Status::Success;
}
#if defined(TRITON_ENABLE_GCS) || defined(TRITON_ENABLE_S3) || \
defined(TRITON_ENABLE_AZURE_STORAGE)
// Helper function to take care of lack of trailing slashes
std::string
AppendSlash(const std::string& name)
{
if (name.empty() || (name.back() == '/')) {
return name;
}
return (name + "/");
}
#endif // TRITON_ENABLE_GCS || TRITON_ENABLE_S3 || TRITON_ENABLE_AZURE_STORAGE
#ifdef TRITON_ENABLE_GCS
namespace gcs = google::cloud::storage;
struct GCSCredential {
std::string path_;
GCSCredential(); // from env var
GCSCredential(triton::common::TritonJson::Value& cred_json);
};
GCSCredential::GCSCredential()
{
const char* path = std::getenv("GOOGLE_APPLICATION_CREDENTIALS");
path_ = (path != nullptr ? std::string(path) : "");
}
GCSCredential::GCSCredential(triton::common::TritonJson::Value& cred_json)
{
cred_json.AsString(&path_);
}
class GCSFileSystem : public FileSystem {
public:
GCSFileSystem(const GCSCredential& gs_cred);
// unify with S3/azure interface
GCSFileSystem(const std::string& path, const GCSCredential& gs_cred)
: GCSFileSystem(gs_cred)
{
}
Status CheckClient();
// unify with S3 interface
Status CheckClient(const std::string& path) { return CheckClient(); }
Status FileExists(const std::string& path, bool* exists) override;
Status IsDirectory(const std::string& path, bool* is_dir) override;
Status FileModificationTime(
const std::string& path, int64_t* mtime_ns) override;
Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents) override;
Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs) override;
Status GetDirectoryFiles(
const std::string& path, std::set<std::string>* files) override;
Status ReadTextFile(const std::string& path, std::string* contents) override;
Status LocalizePath(
const std::string& path,
std::shared_ptr<LocalizedPath>* localized) override;
Status WriteTextFile(
const std::string& path, const std::string& contents) override;
Status WriteBinaryFile(
const std::string& path, const char* contents,
const size_t content_len) override;
Status MakeDirectory(const std::string& dir, const bool recursive) override;
Status MakeTemporaryDirectory(std::string* temp_dir) override;
Status DeletePath(const std::string& path) override;
private:
Status ParsePath(
const std::string& path, std::string* bucket, std::string* object);
Status MetaDataExists(
const std::string path, bool* exists,
google::cloud::StatusOr<gcs::ObjectMetadata>* metadata);
google::cloud::StatusOr<gcs::Client> client_;
};
GCSFileSystem::GCSFileSystem(const GCSCredential& gs_cred)
{
auto creds = gcs::oauth2::CreateServiceAccountCredentialsFromJsonFilePath(
gs_cred.path_);
if (creds) {
client_ = gcs::Client(gcs::ClientOptions(*creds));
}
}
Status
GCSFileSystem::CheckClient()
{
if (!client_) {
return Status(
Status::Code::INTERNAL,
"Unable to create GCS client. Check account credentials.");
}
return Status::Success;
}
Status
GCSFileSystem::ParsePath(
const std::string& path, std::string* bucket, std::string* object)
{
// Get the bucket name and the object path. Return error if input is malformed
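  // For example, "gs://my_bucket/models/resnet" yields bucket "my_bucket" and
  // object "models/resnet", while "gs://my_bucket" yields an empty object.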
int bucket_start = path.find("gs://") + strlen("gs://");
int bucket_end = path.find("/", bucket_start);
  // If there is a second slash, split into bucket and object; otherwise the
  // address contains only the bucket
if (bucket_end > bucket_start) {
*bucket = path.substr(bucket_start, bucket_end - bucket_start);
*object = path.substr(bucket_end + 1);
} else {
*bucket = path.substr(bucket_start);
*object = "";
}
if (bucket->empty()) {
return Status(
Status::Code::INTERNAL, "No bucket name found in path: " + path);
}
return Status::Success;
}
Status
GCSFileSystem::FileExists(const std::string& path, bool* exists)
{
*exists = false;
std::string bucket, object;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object));
// Make a request for metadata and check the response
google::cloud::StatusOr<gcs::ObjectMetadata> object_metadata =
client_->GetObjectMetadata(bucket, object);
if (object_metadata) {
*exists = true;
return Status::Success;
}
// GCS doesn't make objects for directories, so it could still be a directory
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
*exists = is_dir;
return Status::Success;
}
Status
GCSFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
*is_dir = false;
std::string bucket, object_path;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object_path));
// Check if the bucket exists
google::cloud::StatusOr<gcs::BucketMetadata> bucket_metadata =
client_->GetBucketMetadata(bucket);
if (!bucket_metadata) {
return Status(
Status::Code::INTERNAL, "Could not get MetaData for bucket with name " +
bucket + " : " +
bucket_metadata.status().message());
}
// Root case - bucket exists and object path is empty
if (object_path.empty()) {
*is_dir = true;
return Status::Success;
}
// Check whether it has children. If at least one child, it is a directory
for (auto&& object_metadata :
client_->ListObjects(bucket, gcs::Prefix(AppendSlash(object_path)))) {
if (object_metadata) {
*is_dir = true;
break;
}
}
return Status::Success;
}
Status
GCSFileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns)
{
// We don't need to worry about the case when this is a directory
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (is_dir) {
*mtime_ns = 0;
return Status::Success;
}
std::string bucket, object;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object));
// Otherwise check the object metadata for update time
google::cloud::StatusOr<gcs::ObjectMetadata> object_metadata =
client_->GetObjectMetadata(bucket, object);
if (!object_metadata) {
return Status(
Status::Code::INTERNAL, "Failed to get metadata for " + object + " : " +
object_metadata.status().message());
}
// Get duration from time point with respect to object clock
auto update_time = std::chrono::time_point_cast<std::chrono::nanoseconds>(
object_metadata->updated())
.time_since_epoch()
.count();
*mtime_ns = update_time;
return Status::Success;
}
Status
GCSFileSystem::GetDirectoryContents(
const std::string& path, std::set<std::string>* contents)
{
std::string bucket, dir_path;
RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path));
// Append a slash to make it easier to list contents
std::string full_dir = AppendSlash(dir_path);
// Get objects with prefix equal to full directory path
for (auto&& object_metadata :
client_->ListObjects(bucket, gcs::Prefix(full_dir))) {
if (!object_metadata) {
return Status(
Status::Code::INTERNAL, "Could not list contents of directory at " +
path + " : " +
object_metadata.status().message());
}
// In the case of empty directories, the directory itself will appear here
if (object_metadata->name() == full_dir) {
continue;
}
// We have to make sure that subdirectory contents do not appear here
std::string name = object_metadata->name();
int item_start = name.find(full_dir) + full_dir.size();
// GCS response prepends parent directory name
int item_end = name.find("/", item_start);
// Let set take care of subdirectory contents
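    // For example, with full_dir "models/" the object name
    // "models/resnet/1/model.plan" contributes only the immediate child
    // entry "resnet".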
std::string item = name.substr(item_start, item_end - item_start);
contents->insert(item);
}
return Status::Success;
}
Status
GCSFileSystem::GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs)
{
RETURN_IF_ERROR(GetDirectoryContents(path, subdirs));
// Erase non-directory entries...
for (auto iter = subdirs->begin(); iter != subdirs->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
if (!is_dir) {
iter = subdirs->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
GCSFileSystem::GetDirectoryFiles(
const std::string& path, std::set<std::string>* files)
{
RETURN_IF_ERROR(GetDirectoryContents(path, files));
// Erase directory entries...
for (auto iter = files->begin(); iter != files->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({path, *iter}), &is_dir));
if (is_dir) {
iter = files->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
GCSFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
bool exists;
RETURN_IF_ERROR(FileExists(path, &exists));
if (!exists) {
return Status(Status::Code::INTERNAL, "File does not exist at " + path);
}
std::string bucket, object;
ParsePath(path, &bucket, &object);
gcs::ObjectReadStream stream = client_->ReadObject(bucket, object);
if (!stream) {
return Status(
Status::Code::INTERNAL, "Failed to open object read stream for " +
path + " : " + stream.status().message());
}
std::string data = "";
char c;
while (stream.get(c)) {
data += c;
}
*contents = data;
return Status::Success;
}
Status
GCSFileSystem::LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
bool exists;
RETURN_IF_ERROR(FileExists(path, &exists));
if (!exists) {
return Status(
Status::Code::INTERNAL, "directory or file does not exist at " + path);
}
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (!is_dir) {
return Status(
Status::Code::UNSUPPORTED,
"GCS file localization not yet implemented " + path);
}
std::string tmp_folder;
RETURN_IF_ERROR(
triton::core::MakeTemporaryDirectory(FileSystemType::LOCAL, &tmp_folder));
localized->reset(new LocalizedPath(path, tmp_folder));
std::set<std::string> contents, filenames;
RETURN_IF_ERROR(GetDirectoryContents(path, &filenames));
for (auto itr = filenames.begin(); itr != filenames.end(); ++itr) {
contents.insert(JoinPath({path, *itr}));
}
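  // Iteratively walk the remote tree: each pass replaces 'contents' with the
  // children of the entries processed so far, creating a local mirror of each
  // sub-directory and downloading each object into the localized path.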
while (contents.size() != 0) {
std::set<std::string> tmp_contents = contents;
contents.clear();
for (auto iter = tmp_contents.begin(); iter != tmp_contents.end(); ++iter) {
bool is_subdir;
std::string gcs_fpath = *iter;
std::string gcs_removed_path = gcs_fpath.substr(path.size());
std::string local_fpath =
JoinPath({(*localized)->Path(), gcs_removed_path});
RETURN_IF_ERROR(IsDirectory(gcs_fpath, &is_subdir));
if (is_subdir) {
// Create local mirror of sub-directories
#ifdef _WIN32
int status = mkdir(const_cast<char*>(local_fpath.c_str()));
#else
int status = mkdir(
const_cast<char*>(local_fpath.c_str()),
S_IRUSR | S_IWUSR | S_IXUSR);
#endif
if (status == -1) {
return Status(
Status::Code::INTERNAL,
"Failed to create local folder: " + local_fpath +
", errno:" + strerror(errno));
}
// Add sub-directories and deeper files to contents
std::set<std::string> subdir_contents;
RETURN_IF_ERROR(GetDirectoryContents(gcs_fpath, &subdir_contents));
for (auto itr = subdir_contents.begin(); itr != subdir_contents.end();
++itr) {
contents.insert(JoinPath({gcs_fpath, *itr}));
}
} else {
// Create local copy of file
std::string file_bucket, file_object;
RETURN_IF_ERROR(ParsePath(gcs_fpath, &file_bucket, &file_object));
// Send a request to read the object
gcs::ObjectReadStream filestream =
client_->ReadObject(file_bucket, file_object);
if (!filestream) {
return Status(
Status::Code::INTERNAL, "Failed to get object at " + *iter +
" : " +
filestream.status().message());
}
std::string gcs_removed_path = (*iter).substr(path.size());
std::string local_file_path =
JoinPath({(*localized)->Path(), gcs_removed_path});
std::ofstream output_file(local_file_path.c_str(), std::ios::binary);
output_file << filestream.rdbuf();
output_file.close();
}
}
}
return Status::Success;
}
Status
GCSFileSystem::WriteTextFile(
const std::string& path, const std::string& contents)
{
return Status(
Status::Code::UNSUPPORTED,
"Write text file operation not yet implemented " + path);
}
Status
GCSFileSystem::WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len)
{
return Status(
Status::Code::UNSUPPORTED,
"Write text file operation not yet implemented " + path);
}
Status
GCSFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
return Status(
Status::Code::UNSUPPORTED,
"Make temporary directory operation not yet implemented");
}
Status
GCSFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
return Status(
Status::Code::UNSUPPORTED,
"Make temporary directory operation not yet implemented");
}
Status
GCSFileSystem::DeletePath(const std::string& path)
{
return Status(
Status::Code::UNSUPPORTED, "Delete path operation not yet implemented");
}
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_AZURE_STORAGE
namespace as = azure::storage_lite;
const std::string AS_URL_PATTERN = "as://([^/]+)/([^/?]+)(?:/([^?]*))?(\\?.*)?";
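// For example, "as://myaccount.blob.core.windows.net/mycontainer/models/resnet"
// matches with host "myaccount.blob.core.windows.net", container
// "mycontainer", and object path "models/resnet".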
struct ASCredential {
std::string account_str_;
std::string account_key_;
ASCredential(); // from env var
ASCredential(triton::common::TritonJson::Value& cred_json);
};
ASCredential::ASCredential()
{
const auto to_str = [](const char* s) -> std::string {
return (s != nullptr ? std::string(s) : "");
};
const char* account_str = std::getenv("AZURE_STORAGE_ACCOUNT");
const char* account_key = std::getenv("AZURE_STORAGE_KEY");
account_str_ = to_str(account_str);
account_key_ = to_str(account_key);
}
ASCredential::ASCredential(triton::common::TritonJson::Value& cred_json)
{
triton::common::TritonJson::Value account_str_json, account_key_json;
if (cred_json.Find("account_str", &account_str_json))
account_str_json.AsString(&account_str_);
if (cred_json.Find("account_key", &account_key_json))
account_key_json.AsString(&account_key_);
}
class ASFileSystem : public FileSystem {
public:
ASFileSystem(const std::string& path, const ASCredential& as_cred);
Status CheckClient();
// unify with S3 interface
Status CheckClient(const std::string& path) { return CheckClient(); }
Status FileExists(const std::string& path, bool* exists) override;
Status IsDirectory(const std::string& path, bool* is_dir) override;
Status FileModificationTime(
const std::string& path, int64_t* mtime_ns) override;
Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents) override;
Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs) override;
Status GetDirectoryFiles(
const std::string& path, std::set<std::string>* files) override;
Status ReadTextFile(const std::string& path, std::string* contents) override;
Status LocalizePath(
const std::string& path,
std::shared_ptr<LocalizedPath>* localized) override;
Status WriteTextFile(
const std::string& path, const std::string& contents) override;
Status WriteBinaryFile(
const std::string& path, const char* contents,
const size_t content_len) override;
Status MakeDirectory(const std::string& dir, const bool recursive) override;
Status MakeTemporaryDirectory(std::string* temp_dir) override;
Status DeletePath(const std::string& path) override;
private:
Status ParsePath(
const std::string& path, std::string* bucket, std::string* object);
std::shared_ptr<as::blob_client> client_;
Status ListDirectory(
const std::string& path, const std::string& dir_path,
std::function<
Status(const as::list_blobs_segmented_item&, const std::string&)>
func);
Status DownloadFolder(
const std::string& container, const std::string& path,
const std::string& dest);
re2::RE2 as_regex_;
};
Status
ASFileSystem::ParsePath(
const std::string& path, std::string* container, std::string* object)
{
std::string host_name, query;
if (!RE2::FullMatch(path, as_regex_, &host_name, container, object, &query)) {
return Status(
Status::Code::INTERNAL, "Invalid azure storage path: " + path);
}
return Status::Success;
}
ASFileSystem::ASFileSystem(const std::string& path, const ASCredential& as_cred)
: as_regex_(AS_URL_PATTERN)
{
std::shared_ptr<as::storage_account> account = nullptr;
std::string host_name, container, blob_path, query;
if (RE2::FullMatch(
path, as_regex_, &host_name, &container, &blob_path, &query)) {
size_t pos = host_name.rfind(".blob.core.windows.net");
std::string account_name;
if (as_cred.account_str_.empty()) {
if (pos != std::string::npos) {
account_name = host_name.substr(0, pos);
} else {
account_name = host_name;
}
} else {
account_name = as_cred.account_str_;
}
std::shared_ptr<as::storage_credential> cred;
if (!as_cred.account_key_.empty()) {
// Shared Key
cred = std::make_shared<as::shared_key_credential>(
account_name, as_cred.account_key_);
} else {
cred = std::make_shared<as::anonymous_credential>();
}
account = std::make_shared<as::storage_account>(
account_name, cred, /* use_https */ true);
client_ =
std::make_shared<as::blob_client>(account, /*max_concurrency*/ 16);
}
}
Status
ASFileSystem::CheckClient()
{
if (client_ == nullptr) {
return Status(
Status::Code::INTERNAL,
"Unable to create Azure filesystem client. Check account credentials.");
}
return Status::Success;
}
Status
ASFileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns)
{
as::blob_client_wrapper bc(client_);
std::string container, object_path;
RETURN_IF_ERROR(ParsePath(path, &container, &object_path));
auto blobProperty = bc.get_blob_property(container, object_path);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Unable to get blob property for file at " +
path + ", errno:" + strerror(errno));
}
auto time =
std::chrono::system_clock::from_time_t(blobProperty.last_modified);
auto update_time =
std::chrono::time_point_cast<std::chrono::nanoseconds>(time)
.time_since_epoch()
.count();
*mtime_ns = update_time;
return Status::Success;
};
Status
ASFileSystem::ListDirectory(
const std::string& container, const std::string& dir_path,
std::function<
Status(const as::list_blobs_segmented_item&, const std::string&)>
func)
{
as::blob_client_wrapper bc(client_);
// Append a slash to make it easier to list contents
std::string full_dir = AppendSlash(dir_path);
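  // Listing with "/" as the delimiter yields one entry per immediate child of
  // 'full_dir'; virtual sub-directories are reported with is_directory set,
  // which GetDirectorySubdirs and GetDirectoryFiles use to filter the results.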
auto blobs = bc.list_blobs_segmented(container, "/", "", full_dir);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Failed to get contents of directory " +
dir_path + ", errno:" + strerror(errno));
}
for (auto&& item : blobs.blobs) {
std::string name = item.name;
int item_start = name.find(full_dir) + full_dir.size();
int item_end = name.find("/", item_start);
// Let set take care of subdirectory contents
std::string subfile = name.substr(item_start, item_end - item_start);
auto status = func(item, subfile);
if (!status.IsOk()) {
return status;
}
}
return Status::Success;
}
Status
ASFileSystem::GetDirectoryContents(
const std::string& path, std::set<std::string>* contents)
{
auto func = [&](const as::list_blobs_segmented_item& item,
const std::string& dir) {
contents->insert(dir);
return Status::Success;
};
std::string container, dir_path;
RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
return ListDirectory(container, dir_path, func);
}
Status
ASFileSystem::GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs)
{
auto func = [&](const as::list_blobs_segmented_item& item,
const std::string& dir) {
if (item.is_directory) {
subdirs->insert(dir);
}
return Status::Success;
};
std::string container, dir_path;
RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
return ListDirectory(container, dir_path, func);
}
Status
ASFileSystem::GetDirectoryFiles(
const std::string& path, std::set<std::string>* files)
{
auto func = [&](const as::list_blobs_segmented_item& item,
const std::string& file) {
if (!item.is_directory) {
files->insert(file);
}
return Status::Success;
};
std::string container, dir_path;
RETURN_IF_ERROR(ParsePath(path, &container, &dir_path));
return ListDirectory(container, dir_path, func);
}
Status
ASFileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
*is_dir = false;
std::string container, object_path;
RETURN_IF_ERROR(ParsePath(path, &container, &object_path));
as::blob_client_wrapper bc(client_);
auto blobs = bc.list_blobs_segmented(container, "/", "", object_path, 1);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Failed to check if directory at " + path +
", errno:" + strerror(errno));
}
*is_dir = blobs.blobs.size() > 0;
return Status::Success;
};
Status
ASFileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
as::blob_client_wrapper bc(client_);
std::string container, object_path;
RETURN_IF_ERROR(ParsePath(path, &container, &object_path));
using namespace azure::storage_lite;
std::ostringstream out_stream;
bc.download_blob_to_stream(container, object_path, 0, 0, out_stream);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Failed to fetch file stream at " + path +
", errno:" + strerror(errno));
}
*contents = out_stream.str();
return Status::Success;
}
Status
ASFileSystem::FileExists(const std::string& path, bool* exists)
{
*exists = false;
std::string container, object;
RETURN_IF_ERROR(ParsePath(path, &container, &object));
as::blob_client_wrapper bc(client_);
auto blobs = bc.list_blobs_segmented(container, "/", "", object, 1);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Failed to check if file exists at " + path +
", errno:" + strerror(errno));
}
if (blobs.blobs.size() > 0) {
*exists = true;
}
return Status::Success;
}
Status
ASFileSystem::DownloadFolder(
const std::string& container, const std::string& path,
const std::string& dest)
{
as::blob_client_wrapper bc(client_);
auto func = [&](const as::list_blobs_segmented_item& item,
const std::string& dir) {
auto local_path = JoinPath({dest, dir});
auto blob_path = JoinPath({path, dir});
if (item.is_directory) {
int status = mkdir(
const_cast<char*>(local_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR);
if (status == -1) {
return Status(
Status::Code::INTERNAL,
"Failed to create local folder: " + local_path +
", errno:" + strerror(errno));
}
auto ret = DownloadFolder(container, blob_path, local_path);
if (!ret.IsOk()) {
return ret;
}
} else {
time_t last_modified;
bc.download_blob_to_file(container, blob_path, local_path, last_modified);
if (errno != 0) {
return Status(
Status::Code::INTERNAL, "Failed to download file at " + blob_path +
", errno:" + strerror(errno));
}
}
return Status::Success;
};
return ListDirectory(container, path, func);
}
Status
ASFileSystem::LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
bool exists;
RETURN_IF_ERROR(FileExists(path, &exists));
if (!exists) {
return Status(
Status::Code::INTERNAL, "directory or file does not exist at " + path);
}
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (!is_dir) {
return Status(
Status::Code::UNSUPPORTED,
"AS file localization not yet implemented " + path);
}
std::string folder_template = "/tmp/folderXXXXXX";
char* tmp_folder = mkdtemp(const_cast<char*>(folder_template.c_str()));
if (tmp_folder == nullptr) {
return Status(
Status::Code::INTERNAL,
"Failed to create local temp folder: " + folder_template +
", errno:" + strerror(errno));
}
localized->reset(new LocalizedPath(path, tmp_folder));
std::string dest(folder_template);
as::blob_client_wrapper bc(client_);
std::string container, object;
RETURN_IF_ERROR(ParsePath(path, &container, &object));
return DownloadFolder(container, object, dest);
}
Status
ASFileSystem::WriteTextFile(
const std::string& path, const std::string& contents)
{
std::stringstream ss(contents);
std::istream is(ss.rdbuf());
std::string container, object;
RETURN_IF_ERROR(ParsePath(path, &container, &object));
std::vector<std::pair<std::string, std::string>> metadata;
auto ret =
client_->upload_block_blob_from_stream(container, object, is, metadata)
.get();
if (!ret.success()) {
return Status(
Status::Code::INTERNAL,
"Failed to upload blob, Error: " + ret.error().code + ", " +
ret.error().code_name);
}
return Status::Success;
}
Status
ASFileSystem::WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len)
{
return Status(
Status::Code::UNSUPPORTED,
"Write text file operation not yet implemented " + path);
}
Status
ASFileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
return Status(
Status::Code::UNSUPPORTED,
"Make directory operation not yet implemented");
}
Status
ASFileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
return Status(
Status::Code::UNSUPPORTED,
"Make temporary directory operation not yet implemented");
}
Status
ASFileSystem::DeletePath(const std::string& path)
{
return Status(
Status::Code::UNSUPPORTED, "Delete path operation not yet implemented");
}
#endif // TRITON_ENABLE_AZURE_STORAGE
#ifdef TRITON_ENABLE_S3
namespace s3 = Aws::S3;
struct S3Credential {
std::string secret_key_;
std::string key_id_;
std::string region_;
std::string session_token_;
std::string profile_name_;
S3Credential(); // from env var
S3Credential(triton::common::TritonJson::Value& cred_json);
};
S3Credential::S3Credential()
{
const auto to_str = [](const char* s) -> std::string {
return (s != nullptr ? std::string(s) : "");
};
const char* secret_key = std::getenv("AWS_SECRET_ACCESS_KEY");
const char* key_id = std::getenv("AWS_ACCESS_KEY_ID");
const char* region = std::getenv("AWS_DEFAULT_REGION");
const char* session_token = std::getenv("AWS_SESSION_TOKEN");
const char* profile = std::getenv("AWS_PROFILE");
secret_key_ = to_str(secret_key);
key_id_ = to_str(key_id);
region_ = to_str(region);
session_token_ = to_str(session_token);
profile_name_ = to_str(profile);
}
S3Credential::S3Credential(triton::common::TritonJson::Value& cred_json)
{
triton::common::TritonJson::Value secret_key_json, key_id_json, region_json,
session_token_json, profile_json;
if (cred_json.Find("secret_key", &secret_key_json))
secret_key_json.AsString(&secret_key_);
if (cred_json.Find("key_id", &key_id_json))
key_id_json.AsString(&key_id_);
if (cred_json.Find("region", &region_json))
region_json.AsString(&region_);
if (cred_json.Find("session_token", &session_token_json))
session_token_json.AsString(&session_token_);
if (cred_json.Find("profile", &profile_json))
profile_json.AsString(&profile_name_);
}
class S3FileSystem : public FileSystem {
public:
S3FileSystem(const std::string& s3_path, const S3Credential& s3_cred);
Status CheckClient(const std::string& s3_path);
Status FileExists(const std::string& path, bool* exists) override;
Status IsDirectory(const std::string& path, bool* is_dir) override;
Status FileModificationTime(
const std::string& path, int64_t* mtime_ns) override;
Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents) override;
Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs) override;
Status GetDirectoryFiles(
const std::string& path, std::set<std::string>* files) override;
Status ReadTextFile(const std::string& path, std::string* contents) override;
Status LocalizePath(
const std::string& path,
std::shared_ptr<LocalizedPath>* localized) override;
Status WriteTextFile(
const std::string& path, const std::string& contents) override;
Status WriteBinaryFile(
const std::string& path, const char* contents,
const size_t content_len) override;
Status MakeDirectory(const std::string& dir, const bool recursive) override;
Status MakeTemporaryDirectory(std::string* temp_dir) override;
Status DeletePath(const std::string& path) override;
private:
Status ParsePath(
const std::string& path, std::string* bucket, std::string* object);
Status CleanPath(const std::string& s3_path, std::string* clean_path);
std::unique_ptr<s3::S3Client> client_; // init after Aws::InitAPI is called
re2::RE2 s3_regex_;
};
Status
S3FileSystem::ParsePath(
const std::string& path, std::string* bucket, std::string* object)
{
// Cleanup extra slashes
std::string clean_path;
RETURN_IF_ERROR(CleanPath(path, &clean_path));
// Get the bucket name and the object path. Return error if path is malformed
std::string protocol, host_name, host_port;
if (!RE2::FullMatch(
clean_path, s3_regex_, &protocol, &host_name, &host_port, bucket,
object)) {
int bucket_start = clean_path.find("s3://") + strlen("s3://");
int bucket_end = clean_path.find("/", bucket_start);
    // If there is a slash, split into bucket and object; otherwise the
    // address contains only the bucket
if (bucket_end > bucket_start) {
*bucket = clean_path.substr(bucket_start, bucket_end - bucket_start);
*object = clean_path.substr(bucket_end + 1);
} else {
*bucket = clean_path.substr(bucket_start);
*object = "";
}
} else {
// Erase leading '/' that is left behind in object name
if ((*object)[0] == '/') {
object->erase(0, 1);
}
}
if (bucket->empty()) {
return Status(
Status::Code::INTERNAL, "No bucket name found in path: " + path);
}
return Status::Success;
}
Status
S3FileSystem::CleanPath(const std::string& s3_path, std::string* clean_path)
{
// Must handle paths with s3 prefix
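  // For example, "s3://my_bucket//models///resnet/" is cleaned to
  // "s3://my_bucket/models/resnet", and an endpoint-style path such as
  // "s3://https://host:9000//bucket//dir/" becomes
  // "s3://https://host:9000/bucket/dir".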
size_t start = s3_path.find("s3://");
std::string path = "";
if (start != std::string::npos) {
path = s3_path.substr(start + strlen("s3://"));
*clean_path = "s3://";
} else {
path = s3_path;
*clean_path = "";
}
// Must handle paths with https:// or http:// prefix
size_t https_start = path.find("https://");
if (https_start != std::string::npos) {
path = path.substr(https_start + strlen("https://"));
*clean_path += "https://";
} else {
size_t http_start = path.find("http://");
if (http_start != std::string::npos) {
path = path.substr(http_start + strlen("http://"));
*clean_path += "http://";
}
}
// Remove trailing slashes
size_t rtrim_length = path.find_last_not_of('/');
if (rtrim_length == std::string::npos) {
return Status(
Status::Code::INVALID_ARG, "Invalid bucket name: '" + path + "'");
}
// Remove leading slashes
size_t ltrim_length = path.find_first_not_of('/');
if (ltrim_length == std::string::npos) {
return Status(
Status::Code::INVALID_ARG, "Invalid bucket name: '" + path + "'");
}
// Remove extra internal slashes
std::string true_path = path.substr(ltrim_length, rtrim_length + 1);
std::vector<int> slash_locations;
bool previous_slash = false;
for (size_t i = 0; i < true_path.size(); i++) {
if (true_path[i] == '/') {
if (!previous_slash) {
*clean_path += true_path[i];
}
previous_slash = true;
} else {
*clean_path += true_path[i];
previous_slash = false;
}
}
return Status::Success;
}
S3FileSystem::S3FileSystem(
const std::string& s3_path, const S3Credential& s3_cred)
: s3_regex_(
"s3://(http://|https://|)([0-9a-zA-Z\\-.]+):([0-9]+)/"
"([0-9a-z.\\-]+)(((/[0-9a-zA-Z.\\-_]+)*)?)")
{
  // Initialize the AWS API if it has not been initialized already
Aws::SDKOptions options;
static std::once_flag onceFlag;
std::call_once(onceFlag, [&options] { Aws::InitAPI(options); });
Aws::Client::ClientConfiguration config;
Aws::Auth::AWSCredentials credentials;
  // Check environment variables for S3 credentials, then the AWS profile,
  // then the default configuration
if (!s3_cred.secret_key_.empty() && !s3_cred.key_id_.empty()) {
credentials.SetAWSAccessKeyId(s3_cred.key_id_.c_str());
credentials.SetAWSSecretKey(s3_cred.secret_key_.c_str());
if (!s3_cred.session_token_.empty()) {
credentials.SetSessionToken(s3_cred.session_token_.c_str());
}
config = Aws::Client::ClientConfiguration();
if (!s3_cred.region_.empty()) {
config.region = s3_cred.region_.c_str();
}
} else if (!s3_cred.profile_name_.empty()) {
config = Aws::Client::ClientConfiguration(s3_cred.profile_name_.c_str());
} else {
config = Aws::Client::ClientConfiguration("default");
}
// Cleanup extra slashes
std::string clean_path;
LOG_STATUS_ERROR(CleanPath(s3_path, &clean_path), "failed to parse S3 path");
std::string protocol, host_name, host_port, bucket, object;
if (RE2::FullMatch(
clean_path, s3_regex_, &protocol, &host_name, &host_port, &bucket,
&object)) {
config.endpointOverride = Aws::String(host_name + ":" + host_port);
if (protocol == "https://") {
config.scheme = Aws::Http::Scheme::HTTPS;
} else {
config.scheme = Aws::Http::Scheme::HTTP;
}
}
if (!s3_cred.secret_key_.empty() && !s3_cred.key_id_.empty()) {
client_ = std::make_unique<s3::S3Client>(
credentials, config,
Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
      /*useVirtualAddressing*/ false);
} else {
client_ = std::make_unique<s3::S3Client>(
config, Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
        /*useVirtualAddressing*/ false);
}
}
Status
S3FileSystem::CheckClient(const std::string& s3_path)
{
std::string bucket, object_path;
RETURN_IF_ERROR(ParsePath(s3_path, &bucket, &object_path));
  // Check if we can connect to the bucket
s3::Model::HeadBucketRequest head_request;
head_request.WithBucket(bucket.c_str());
if (!client_->HeadBucket(head_request).IsSuccess()) {
return Status(
Status::Code::INTERNAL,
"Unable to create S3 filesystem client. Check account credentials.");
}
return Status::Success;
}
Status
S3FileSystem::FileExists(const std::string& path, bool* exists)
{
*exists = false;
// S3 doesn't make objects for directories, so it could still be a directory
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (is_dir) {
*exists = is_dir;
return Status::Success;
}
std::string bucket, object;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object));
// Construct request for object metadata
s3::Model::HeadObjectRequest head_request;
head_request.SetBucket(bucket.c_str());
head_request.SetKey(object.c_str());
auto head_object_outcome = client_->HeadObject(head_request);
if (!head_object_outcome.IsSuccess()) {
if (head_object_outcome.GetError().GetErrorType() !=
s3::S3Errors::RESOURCE_NOT_FOUND) {
return Status(
Status::Code::INTERNAL,
"Could not get MetaData for object at " + path +
" due to exception: " +
head_object_outcome.GetError().GetExceptionName() +
", error message: " +
head_object_outcome.GetError().GetMessage());
}
} else {
*exists = true;
}
return Status::Success;
}
Status
S3FileSystem::IsDirectory(const std::string& path, bool* is_dir)
{
*is_dir = false;
std::string bucket, object_path;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object_path));
// Check if the bucket exists
s3::Model::HeadBucketRequest head_request;
head_request.WithBucket(bucket.c_str());
auto head_bucket_outcome = client_->HeadBucket(head_request);
if (!head_bucket_outcome.IsSuccess()) {
return Status(
Status::Code::INTERNAL,
"Could not get MetaData for bucket with name " + bucket +
" due to exception: " +
head_bucket_outcome.GetError().GetExceptionName() +
", error message: " + head_bucket_outcome.GetError().GetMessage());
}
// Root case - bucket exists and object path is empty
if (object_path.empty()) {
*is_dir = true;
return Status::Success;
}
// List the objects in the bucket
s3::Model::ListObjectsRequest list_objects_request;
list_objects_request.SetBucket(bucket.c_str());
list_objects_request.SetPrefix(AppendSlash(object_path).c_str());
auto list_objects_outcome = client_->ListObjects(list_objects_request);
if (list_objects_outcome.IsSuccess()) {
*is_dir = !list_objects_outcome.GetResult().GetContents().empty();
} else {
return Status(
Status::Code::INTERNAL,
"Failed to list objects with prefix " + path + " due to exception: " +
list_objects_outcome.GetError().GetExceptionName() +
", error message: " + list_objects_outcome.GetError().GetMessage());
}
return Status::Success;
}
Status
S3FileSystem::FileModificationTime(const std::string& path, int64_t* mtime_ns)
{
// Directories (prefixes) do not have a meaningful modification time in S3,
// so return 0 for them
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (is_dir) {
*mtime_ns = 0;
return Status::Success;
}
std::string bucket, object;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object));
// Send a request for the object's metadata
s3::Model::HeadObjectRequest head_request;
head_request.SetBucket(bucket.c_str());
head_request.SetKey(object.c_str());
// If request succeeds, copy over the modification time
auto head_object_outcome = client_->HeadObject(head_request);
if (head_object_outcome.IsSuccess()) {
*mtime_ns = head_object_outcome.GetResult().GetLastModified().Millis() *
NANOS_PER_MILLIS;
} else {
return Status(
Status::Code::INTERNAL,
"Failed to get modification time for object at " + path +
" due to exception: " +
head_object_outcome.GetError().GetExceptionName() +
", error message: " + head_object_outcome.GetError().GetMessage());
}
return Status::Success;
}
Status
S3FileSystem::GetDirectoryContents(
const std::string& path, std::set<std::string>* contents)
{
// Parse bucket and dir_path
std::string bucket, dir_path, full_dir;
RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path));
std::string true_path = "s3://" + bucket + '/' + dir_path;
// Capture the full path to facilitate content listing
full_dir = AppendSlash(dir_path);
// Issue request for objects with prefix
s3::Model::ListObjectsRequest objects_request;
objects_request.SetBucket(bucket.c_str());
objects_request.SetPrefix(full_dir.c_str());
auto list_objects_outcome = client_->ListObjects(objects_request);
if (list_objects_outcome.IsSuccess()) {
Aws::Vector<Aws::S3::Model::Object> object_list =
list_objects_outcome.GetResult().GetContents();
for (auto const& s3_object : object_list) {
// In the case of empty directories, the directory itself will appear here
if (s3_object.GetKey().c_str() == full_dir) {
continue;
}
// We have to make sure that subdirectory contents do not appear here
std::string name(s3_object.GetKey().c_str());
int item_start = name.find(full_dir) + full_dir.size();
// S3 response prepends parent directory name
int item_end = name.find("/", item_start);
// Let set take care of subdirectory contents
std::string item = name.substr(item_start, item_end - item_start);
contents->insert(item);
}
} else {
return Status(
Status::Code::INTERNAL,
"Could not list contents of directory at " + true_path +
" due to exception: " +
list_objects_outcome.GetError().GetExceptionName() +
", error message: " + list_objects_outcome.GetError().GetMessage());
}
return Status::Success;
}
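// Illustrative sketch of the listing logic above (the object keys are made
// up): with dir_path "models/resnet", the keys
//   "models/resnet/config.pbtxt" and "models/resnet/1/model.onnx"
// yield contents {"config.pbtxt", "1"} -- the substr()/set combination keeps
// only the immediate children and de-duplicates subdirectory entries.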
Status
S3FileSystem::GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs)
{
// Parse bucket and dir_path
std::string bucket, dir_path;
RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path));
std::string true_path = "s3://" + bucket + '/' + dir_path;
RETURN_IF_ERROR(GetDirectoryContents(true_path, subdirs));
// Erase non-directory entries...
for (auto iter = subdirs->begin(); iter != subdirs->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({true_path, *iter}), &is_dir));
if (!is_dir) {
iter = subdirs->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
S3FileSystem::GetDirectoryFiles(
const std::string& path, std::set<std::string>* files)
{
// Parse bucket and dir_path
std::string bucket, dir_path;
RETURN_IF_ERROR(ParsePath(path, &bucket, &dir_path));
std::string true_path = "s3://" + bucket + '/' + dir_path;
RETURN_IF_ERROR(GetDirectoryContents(true_path, files));
// Erase directory entries...
for (auto iter = files->begin(); iter != files->end();) {
bool is_dir;
RETURN_IF_ERROR(IsDirectory(JoinPath({true_path, *iter}), &is_dir));
if (is_dir) {
iter = files->erase(iter);
} else {
++iter;
}
}
return Status::Success;
}
Status
S3FileSystem::ReadTextFile(const std::string& path, std::string* contents)
{
bool exists;
RETURN_IF_ERROR(FileExists(path, &exists));
if (!exists) {
return Status(Status::Code::INTERNAL, "File does not exist at " + path);
}
std::string bucket, object;
RETURN_IF_ERROR(ParsePath(path, &bucket, &object));
// Send a request to get the object's contents
s3::Model::GetObjectRequest object_request;
object_request.SetBucket(bucket.c_str());
object_request.SetKey(object.c_str());
auto get_object_outcome = client_->GetObject(object_request);
if (get_object_outcome.IsSuccess()) {
auto& object_result = get_object_outcome.GetResultWithOwnership().GetBody();
std::string data = "";
char c;
while (object_result.get(c)) {
data += c;
}
*contents = data;
} else {
return Status(
Status::Code::INTERNAL,
"Failed to get object at " + path + " due to exception: " +
get_object_outcome.GetError().GetExceptionName() +
", error message: " + get_object_outcome.GetError().GetMessage());
}
return Status::Success;
}
Status
S3FileSystem::LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
// Check if the directory or file exists
bool exists;
RETURN_IF_ERROR(FileExists(path, &exists));
if (!exists) {
return Status(
Status::Code::INTERNAL, "directory or file does not exist at " + path);
}
// Cleanup extra slashes
std::string clean_path;
RETURN_IF_ERROR(CleanPath(path, &clean_path));
// Remove protocol and host name and port
std::string effective_path, protocol, host_name, host_port, bucket, object;
if (RE2::FullMatch(
clean_path, s3_regex_, &protocol, &host_name, &host_port, &bucket,
&object)) {
effective_path = "s3://" + bucket + object;
} else {
effective_path = path;
}
// Create temporary directory
std::string tmp_folder;
RETURN_IF_ERROR(
triton::core::MakeTemporaryDirectory(FileSystemType::LOCAL, &tmp_folder));
// Specify contents to be downloaded
std::set<std::string> contents;
bool is_dir;
RETURN_IF_ERROR(IsDirectory(path, &is_dir));
if (is_dir) {
// Set localized path
localized->reset(new LocalizedPath(effective_path, tmp_folder));
// Specify the entire directory to be downloaded
std::set<std::string> filenames;
RETURN_IF_ERROR(GetDirectoryContents(effective_path, &filenames));
for (auto itr = filenames.begin(); itr != filenames.end(); ++itr) {
contents.insert(JoinPath({effective_path, *itr}));
}
} else {
// Set localized path
std::string filename =
effective_path.substr(effective_path.find_last_of('/') + 1);
localized->reset(
new LocalizedPath(effective_path, JoinPath({tmp_folder, filename})));
// Specify only the file to be downloaded
contents.insert(effective_path);
}
// Download all specified contents and nested contents
while (contents.size() != 0) {
std::set<std::string> tmp_contents = contents;
contents.clear();
for (auto iter = tmp_contents.begin(); iter != tmp_contents.end(); ++iter) {
std::string s3_fpath = *iter;
std::string s3_removed_path = s3_fpath.substr(effective_path.size());
std::string local_fpath =
s3_removed_path.empty()
? (*localized)->Path()
: JoinPath({(*localized)->Path(), s3_removed_path});
bool is_subdir;
RETURN_IF_ERROR(IsDirectory(s3_fpath, &is_subdir));
if (is_subdir) {
// Create local mirror of sub-directories
#ifdef _WIN32
int status = mkdir(const_cast<char*>(local_fpath.c_str()));
#else
int status = mkdir(
const_cast<char*>(local_fpath.c_str()),
S_IRUSR | S_IWUSR | S_IXUSR);
#endif
if (status == -1) {
return Status(
Status::Code::INTERNAL,
"Failed to create local folder: " + local_fpath +
", errno:" + strerror(errno));
}
// Add sub-directories and deeper files to contents
std::set<std::string> subdir_contents;
RETURN_IF_ERROR(GetDirectoryContents(s3_fpath, &subdir_contents));
for (auto itr = subdir_contents.begin(); itr != subdir_contents.end();
++itr) {
contents.insert(JoinPath({s3_fpath, *itr}));
}
} else {
// Create local copy of file
std::string file_bucket, file_object;
RETURN_IF_ERROR(ParsePath(s3_fpath, &file_bucket, &file_object));
s3::Model::GetObjectRequest object_request;
object_request.SetBucket(file_bucket.c_str());
object_request.SetKey(file_object.c_str());
auto get_object_outcome = client_->GetObject(object_request);
if (get_object_outcome.IsSuccess()) {
auto& retrieved_file =
get_object_outcome.GetResultWithOwnership().GetBody();
std::ofstream output_file(local_fpath.c_str(), std::ios::binary);
output_file << retrieved_file.rdbuf();
output_file.close();
} else {
return Status(
Status::Code::INTERNAL,
"Failed to get object at " + s3_fpath + " due to exception: " +
get_object_outcome.GetError().GetExceptionName() +
", error message: " +
get_object_outcome.GetError().GetMessage());
}
}
}
}
return Status::Success;
}
Status
S3FileSystem::WriteTextFile(
const std::string& path, const std::string& contents)
{
return Status(
Status::Code::UNSUPPORTED,
"Write text file operation not yet implemented " + path);
}
Status
S3FileSystem::WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len)
{
return Status(
Status::Code::UNSUPPORTED,
"Write text file operation not yet implemented " + path);
}
Status
S3FileSystem::MakeDirectory(const std::string& dir, const bool recursive)
{
return Status(
Status::Code::UNSUPPORTED,
"Make directory operation not yet implemented");
}
Status
S3FileSystem::MakeTemporaryDirectory(std::string* temp_dir)
{
return Status(
Status::Code::UNSUPPORTED,
"Make temporary directory operation not yet implemented");
}
Status
S3FileSystem::DeletePath(const std::string& path)
{
return Status(
Status::Code::UNSUPPORTED, "Delete path operation not yet implemented");
}
#endif // TRITON_ENABLE_S3
class FileSystemManager {
public:
Status GetFileSystem(
const std::string& path, std::shared_ptr<FileSystem>& file_system);
Status GetFileSystem(
FileSystemType type, std::shared_ptr<FileSystem>& file_system);
FileSystemManager();
private:
template <class CacheType, class CredentialType, class FileSystemType>
Status GetFileSystem(
const std::string& path, CacheType& cache,
std::shared_ptr<FileSystem>& file_system);
template <class CacheType, class CredentialType, class FileSystemType>
Status ReturnErrorOrReload(
const Status& load_status, const Status& error_status,
const std::string& path, CacheType& cache,
std::shared_ptr<FileSystem>& file_system);
Status LoadCredentials(bool flush_cache = false);
template <class CacheType, class CredentialType, class FileSystemType>
static void LoadCredential(
triton::common::TritonJson::Value& creds_json, const char* fs_type,
CacheType& cache);
template <class CredentialType, class FileSystemType>
static void SortCache(
std::vector<std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>>&
cache);
template <class CredentialType, class FileSystemType>
static Status GetLongestMatchingNameIndex(
const std::vector<std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>>& cache,
const std::string& path, size_t& idx);
std::shared_ptr<LocalFileSystem> local_fs_;
std::mutex mu_;  // protects concurrent access to the member variables
bool is_cached_;  // credential names are cached; file systems are lazily loaded
// cloud credential cache should be sorted in descending name length order
// [(name_long, credential, file_system), (name, ...)]
#ifdef TRITON_ENABLE_GCS
std::vector<
std::tuple<std::string, GCSCredential, std::shared_ptr<GCSFileSystem>>>
gs_cache_;
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
std::vector<
std::tuple<std::string, S3Credential, std::shared_ptr<S3FileSystem>>>
s3_cache_;
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
std::vector<
std::tuple<std::string, ASCredential, std::shared_ptr<ASFileSystem>>>
as_cache_;
#endif // TRITON_ENABLE_AZURE_STORAGE
};
FileSystemManager::FileSystemManager()
: local_fs_(new LocalFileSystem()), is_cached_(false)
{
}
Status
FileSystemManager::GetFileSystem(
const std::string& path, std::shared_ptr<FileSystem>& file_system)
{
// Check if this is a GCS path (gs://$BUCKET_NAME)
if (!path.empty() && !path.rfind("gs://", 0)) {
#ifndef TRITON_ENABLE_GCS
return Status(
Status::Code::INTERNAL,
"gs:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_GCS=ON.");
#else
return GetFileSystem<
std::vector<std::tuple<
std::string, GCSCredential, std::shared_ptr<GCSFileSystem>>>,
GCSCredential, GCSFileSystem>(path, gs_cache_, file_system);
#endif // TRITON_ENABLE_GCS
}
// Check if this is an S3 path (s3://$BUCKET_NAME)
if (!path.empty() && !path.rfind("s3://", 0)) {
#ifndef TRITON_ENABLE_S3
return Status(
Status::Code::INTERNAL,
"s3:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_S3=ON.");
#else
return GetFileSystem<
std::vector<std::tuple<
std::string, S3Credential, std::shared_ptr<S3FileSystem>>>,
S3Credential, S3FileSystem>(path, s3_cache_, file_system);
#endif // TRITON_ENABLE_S3
}
// Check if this is an Azure Storage path
if (!path.empty() && !path.rfind("as://", 0)) {
#ifndef TRITON_ENABLE_AZURE_STORAGE
return Status(
Status::Code::INTERNAL,
"as:// file-system not supported. To enable, build with "
"-DTRITON_ENABLE_AZURE_STORAGE=ON.");
#else
return GetFileSystem<
std::vector<std::tuple<
std::string, ASCredential, std::shared_ptr<ASFileSystem>>>,
ASCredential, ASFileSystem>(path, as_cache_, file_system);
#endif // TRITON_ENABLE_AZURE_STORAGE
}
// Assume path is for local filesystem
file_system = local_fs_;
return Status::Success;
}
Status
FileSystemManager::GetFileSystem(
FileSystemType type, std::shared_ptr<FileSystem>& file_system)
{
// only LOCAL and GCS are not path-dependent and can be accessed by type
switch (type) {
case FileSystemType::LOCAL:
return GetFileSystem("", file_system);
case FileSystemType::GCS:
return GetFileSystem("gs://", file_system);
case FileSystemType::S3:
return Status(
Status::Code::UNSUPPORTED,
"S3 filesystem cannot be accessed by type");
case FileSystemType::AS:
return Status(
Status::Code::UNSUPPORTED,
"AS filesystem cannot be accessed by type");
default:
return Status(Status::Code::UNSUPPORTED, "Unsupported filesystem type");
}
}
template <class CacheType, class CredentialType, class FileSystemType>
Status
FileSystemManager::GetFileSystem(
const std::string& path, CacheType& cache,
std::shared_ptr<FileSystem>& file_system)
{
const Status& cred_status = LoadCredentials();
if (cred_status.IsOk() ||
cred_status.StatusCode() == Status::Code::ALREADY_EXISTS) {
// Find credential
size_t idx;
const Status& match_status = GetLongestMatchingNameIndex(cache, path, idx);
if (!match_status.IsOk()) {
return ReturnErrorOrReload<CacheType, CredentialType, FileSystemType>(
cred_status, match_status, path, cache, file_system);
}
// Find or lazy load file system
std::shared_ptr<FileSystemType> fs = std::get<2>(cache[idx]);
if (fs == nullptr) {
std::string cred_name = std::get<0>(cache[idx]);
CredentialType cred = std::get<1>(cache[idx]);
fs = std::make_shared<FileSystemType>(path, cred);
cache[idx] = std::make_tuple(cred_name, cred, fs);
}
// Check client
const Status& client_status = fs->CheckClient(path);
if (!client_status.IsOk()) {
return ReturnErrorOrReload<CacheType, CredentialType, FileSystemType>(
cred_status, client_status, path, cache, file_system);
}
// Return client
file_system = fs;
return Status::Success;
}
return cred_status;
}
template <class CacheType, class CredentialType, class FileSystemType>
Status
FileSystemManager::ReturnErrorOrReload(
const Status& load_status, const Status& error_status,
const std::string& path, CacheType& cache,
std::shared_ptr<FileSystem>& file_system)
{
if (load_status.StatusCode() == Status::Code::ALREADY_EXISTS) {
return error_status;
}
LoadCredentials(true); // flush cache
return GetFileSystem<CacheType, CredentialType, FileSystemType>(
path, cache, file_system);
}
// Return status meaning:
// - SUCCESS -> credentials (re)loaded from the credential file or environment
// - ALREADY_EXISTS, "Cached" -> credentials already loaded, nothing to do
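// Illustrative sketch of a credential file referenced by
// TRITON_CLOUD_CREDENTIAL_PATH. The top-level keys ("gs", "s3", "as") and the
// name -> credential layout follow LoadCredential() below; the field names
// and value shapes inside each credential entry are assumptions here, since
// the credential constructors are defined elsewhere:
//
//   {
//     "s3": {
//       "": { "key_id": "...", "secret_key": "...", "region": "us-east-1" },
//       "s3://bucket-a": { "key_id": "...", "secret_key": "..." }
//     },
//     "gs": { "": "/path/to/service-account.json" },
//     "as": { "": { "account_str": "...", "account_key": "..." } }
//   }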
Status
FileSystemManager::LoadCredentials(bool flush_cache)
{
// prevent concurrent access into class variables
std::lock_guard<std::mutex> lock(mu_);
// check if credential is already cached
if (is_cached_ && !flush_cache) {
return Status(Status::Code::ALREADY_EXISTS, "Cached");
}
const char* file_path_c_str = std::getenv("TRITON_CLOUD_CREDENTIAL_PATH");
if (file_path_c_str != nullptr) {
// Load from credential file
std::string file_path = std::string(file_path_c_str);
LOG_VERBOSE(1) << "Reading cloud credential from " << file_path;
triton::common::TritonJson::Value creds_json;
std::string cred_file_content;
RETURN_IF_ERROR(local_fs_->ReadTextFile(file_path, &cred_file_content));
RETURN_IF_ERROR(creds_json.Parse(cred_file_content));
#ifdef TRITON_ENABLE_GCS
// load GCS credentials
LoadCredential<
std::vector<std::tuple<
std::string, GCSCredential, std::shared_ptr<GCSFileSystem>>>,
GCSCredential, GCSFileSystem>(creds_json, "gs", gs_cache_);
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// load S3 credentials
LoadCredential<
std::vector<std::tuple<
std::string, S3Credential, std::shared_ptr<S3FileSystem>>>,
S3Credential, S3FileSystem>(creds_json, "s3", s3_cache_);
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// load AS credentials
LoadCredential<
std::vector<std::tuple<
std::string, ASCredential, std::shared_ptr<ASFileSystem>>>,
ASCredential, ASFileSystem>(creds_json, "as", as_cache_);
#endif // TRITON_ENABLE_AZURE_STORAGE
} else {
// Load from environment variables
LOG_VERBOSE(1) << "TRITON_CLOUD_CREDENTIAL_PATH environment variable is "
"not set, reading from environment variables";
#ifdef TRITON_ENABLE_GCS
// load GCS credentials
gs_cache_.clear();
gs_cache_.push_back(
std::make_tuple("", GCSCredential(), std::shared_ptr<GCSFileSystem>()));
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// load S3 credentials
s3_cache_.clear();
s3_cache_.push_back(
std::make_tuple("", S3Credential(), std::shared_ptr<S3FileSystem>()));
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// load AS credentials
as_cache_.clear();
as_cache_.push_back(
std::make_tuple("", ASCredential(), std::shared_ptr<ASFileSystem>()));
#endif // TRITON_ENABLE_AZURE_STORAGE
}
is_cached_ = true;
return Status::Success;
}
template <class CacheType, class CredentialType, class FileSystemType>
void
FileSystemManager::LoadCredential(
triton::common::TritonJson::Value& creds_json, const char* fs_type,
CacheType& cache)
{
cache.clear();
triton::common::TritonJson::Value creds_fs_json;
if (creds_json.Find(fs_type, &creds_fs_json)) {
std::vector<std::string> cred_names;
creds_fs_json.Members(&cred_names);
for (size_t i = 0; i < cred_names.size(); i++) {
std::string cred_name = cred_names[i];
triton::common::TritonJson::Value cred_json;
creds_fs_json.Find(cred_name.c_str(), &cred_json);
cache.push_back(std::make_tuple(
cred_name, CredentialType(cred_json),
std::shared_ptr<FileSystemType>()));
}
SortCache(cache);
}
}
template <class CredentialType, class FileSystemType>
void
FileSystemManager::SortCache(
std::vector<std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>>& cache)
{
std::sort(
cache.begin(), cache.end(),
[](std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>
a,
std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>
b) { return std::get<0>(a).size() > std::get<0>(b).size(); });
}
template <class CredentialType, class FileSystemType>
Status
FileSystemManager::GetLongestMatchingNameIndex(
const std::vector<std::tuple<
std::string, CredentialType, std::shared_ptr<FileSystemType>>>& cache,
const std::string& path, size_t& idx)
{
for (size_t i = 0; i < cache.size(); i++) {
if (!path.rfind(std::get<0>(cache[i]), 0)) {
idx = i;
LOG_VERBOSE(1) << "Using credential " + std::get<0>(cache[i]) +
" for path " + path;
return Status::Success;
}
}
return Status(
Status::Code::NOT_FOUND, "Cannot match credential for path " + path);
}
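// Illustrative sketch: with cached credential names {"s3://bucket-a", ""}
// (already sorted by SortCache() in descending length), the path
// "s3://bucket-a/models" matches index 0, while "s3://other-bucket/models"
// falls through to the catch-all "" entry.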
static FileSystemManager fsm_;
} // namespace
// FIXME: Does Windows support '/'? If so, the code below doesn't need to change
bool
IsAbsolutePath(const std::string& path)
{
return !path.empty() && (path[0] == '/');
}
std::string
JoinPath(std::initializer_list<std::string> segments)
{
std::string joined;
for (const auto& seg : segments) {
if (joined.empty()) {
joined = seg;
} else if (IsAbsolutePath(seg)) {
if (joined[joined.size() - 1] == '/') {
joined.append(seg.substr(1));
} else {
joined.append(seg);
}
} else { // !IsAbsolutePath(seg)
if (joined[joined.size() - 1] != '/') {
joined.append("/");
}
joined.append(seg);
}
}
return joined;
}
std::string
BaseName(const std::string& path)
{
if (path.empty()) {
return path;
}
size_t last = path.size() - 1;
while ((last > 0) && (path[last] == '/')) {
last -= 1;
}
if (path[last] == '/') {
return std::string();
}
const size_t idx = path.find_last_of("/", last);
if (idx == std::string::npos) {
return path.substr(0, last + 1);
}
return path.substr(idx + 1, last - idx);
}
std::string
DirName(const std::string& path)
{
if (path.empty()) {
return path;
}
size_t last = path.size() - 1;
while ((last > 0) && (path[last] == '/')) {
last -= 1;
}
if (path[last] == '/') {
return std::string("/");
}
const size_t idx = path.find_last_of("/", last);
if (idx == std::string::npos) {
return std::string(".");
}
if (idx == 0) {
return std::string("/");
}
return path.substr(0, idx);
}
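// Illustrative examples of the path helpers above (made-up paths):
//
//   JoinPath({"/models", "resnet50", "1"})    -> "/models/resnet50/1"
//   BaseName("/models/resnet50/")             -> "resnet50"
//   DirName("/models/resnet50/config.pbtxt")  -> "/models/resnet50"
//   DirName("config.pbtxt")                   -> "."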
Status
FileExists(const std::string& path, bool* exists)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->FileExists(path, exists);
}
Status
IsDirectory(const std::string& path, bool* is_dir)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->IsDirectory(path, is_dir);
}
Status
FileModificationTime(const std::string& path, int64_t* mtime_ns)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->FileModificationTime(path, mtime_ns);
}
Status
GetDirectoryContents(const std::string& path, std::set<std::string>* contents)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->GetDirectoryContents(path, contents);
}
Status
GetDirectorySubdirs(const std::string& path, std::set<std::string>* subdirs)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->GetDirectorySubdirs(path, subdirs);
}
Status
GetDirectoryFiles(
const std::string& path, const bool skip_hidden_files,
std::set<std::string>* files)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
std::set<std::string> all_files;
RETURN_IF_ERROR(fs->GetDirectoryFiles(path, &all_files));
// Remove the hidden files
for (auto f : all_files) {
if ((f[0] != '.') || (!skip_hidden_files)) {
files->insert(f);
}
}
return Status::Success;
}
Status
ReadTextFile(const std::string& path, std::string* contents)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->ReadTextFile(path, contents);
}
Status
ReadTextProto(const std::string& path, google::protobuf::Message* msg)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
std::string contents;
RETURN_IF_ERROR(fs->ReadTextFile(path, &contents));
if (!google::protobuf::TextFormat::ParseFromString(contents, msg)) {
return Status(
Status::Code::INTERNAL, "failed to read text proto from " + path);
}
return Status::Success;
}
Status
LocalizePath(const std::string& path, std::shared_ptr<LocalizedPath>* localized)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->LocalizePath(path, localized);
}
Status
WriteTextProto(const std::string& path, const google::protobuf::Message& msg)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
std::string prototxt;
if (!google::protobuf::TextFormat::PrintToString(msg, &prototxt)) {
return Status(
Status::Code::INTERNAL, "failed to write text proto to " + path);
}
return fs->WriteTextFile(path, prototxt);
}
Status
WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->WriteBinaryFile(path, contents, content_len);
}
Status
ReadBinaryProto(const std::string& path, google::protobuf::MessageLite* msg)
{
std::string msg_str;
RETURN_IF_ERROR(ReadTextFile(path, &msg_str));
google::protobuf::io::CodedInputStream coded_stream(
reinterpret_cast<const uint8_t*>(msg_str.c_str()), msg_str.size());
coded_stream.SetTotalBytesLimit(INT_MAX);
if (!msg->ParseFromCodedStream(&coded_stream)) {
return Status(
Status::Code::INTERNAL, "Can't parse " + path + " as binary proto");
}
return Status::Success;
}
Status
MakeDirectory(const std::string& dir, const bool recursive)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(dir, fs));
return fs->MakeDirectory(dir, recursive);
}
Status
MakeTemporaryDirectory(const FileSystemType type, std::string* temp_dir)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(type, fs));
return fs->MakeTemporaryDirectory(temp_dir);
}
Status
DeletePath(const std::string& path)
{
std::shared_ptr<FileSystem> fs;
RETURN_IF_ERROR(fsm_.GetFileSystem(path, fs));
return fs->DeletePath(path);
}
Status
GetFileSystemType(const std::string& path, FileSystemType* type)
{
if (path.empty()) {
return Status(
Status::Code::INVALID_ARG,
"Can not infer filesystem type from empty path");
}
#ifdef TRITON_ENABLE_GCS
// Check if this is a GCS path (gs://$BUCKET_NAME)
if (!path.rfind("gs://", 0)) {
*type = FileSystemType::GCS;
return Status::Success;
}
#endif // TRITON_ENABLE_GCS
#ifdef TRITON_ENABLE_S3
// Check if this is an S3 path (s3://$BUCKET_NAME)
if (!path.rfind("s3://", 0)) {
*type = FileSystemType::S3;
return Status::Success;
}
#endif // TRITON_ENABLE_S3
#ifdef TRITON_ENABLE_AZURE_STORAGE
// Check if this is an Azure Storage path
if (!path.rfind("as://", 0)) {
*type = FileSystemType::AS;
return Status::Success;
}
#endif // TRITON_ENABLE_AZURE_STORAGE
// Assume path is for local filesystem
*type = FileSystemType::LOCAL;
return Status::Success;
}
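// Illustrative examples (assuming the corresponding TRITON_ENABLE_* build
// options are enabled):
//
//   GetFileSystemType("s3://bucket/model", &type) -> FileSystemType::S3
//   GetFileSystemType("/opt/models", &type)       -> FileSystemType::LOCAL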
const std::string&
FileSystemTypeString(const FileSystemType type)
{
static const std::string local_str("LOCAL");
static const std::string gcs_str("GCS");
static const std::string s3_str("S3");
static const std::string as_str("AS");
static const std::string unknown_str("UNKNOWN");
switch (type) {
case FileSystemType::LOCAL:
return local_str;
case FileSystemType::GCS:
return gcs_str;
case FileSystemType::S3:
return s3_str;
case FileSystemType::AS:
return as_str;
default:
return unknown_str;
}
}
}} // namespace triton::core
// Copyright (c) 2019-2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef _WIN32
// Remove GetObject definition from windows.h, which can cause
// a naming collision when GetObject is called.
// https://github.com/Tencent/rapidjson/issues/1448
#undef GetObject
#endif // _WIN32
#include <string>
#include "google/protobuf/message.h"
#include "status.h"
namespace triton { namespace core {
enum class FileSystemType { LOCAL, GCS, S3, AS };
// This class stores the paths of local temporary files needed for loading
// models from Cloud repositories and performs necessary cleanup after the
// models are loaded.
class LocalizedPath {
public:
// Create an object for a path that is already local.
LocalizedPath(const std::string& original_path)
: original_path_(original_path)
{
}
// Create an object for a remote path. Store both the original path and the
// temporary local path.
LocalizedPath(
const std::string& original_path, const std::string& local_path)
: original_path_(original_path), local_path_(local_path)
{
}
// Destructor. Remove temporary local storage associated with the object.
// If the local path is a directory, delete the directory.
// If the local path is a file, delete the directory containing the file.
~LocalizedPath();
// Return the localized path represented by this object.
const std::string& Path() const
{
return (local_path_.empty()) ? original_path_ : local_path_;
}
// Maintain a vector of LocalizedPath that should be kept available in the
// tmp directory for the lifetime of this object
// FIXME: Remove when no longer required
std::vector<std::shared_ptr<LocalizedPath>> other_localized_path;
private:
std::string original_path_;
std::string local_path_;
};
/// Is a path an absolute path?
/// \param path The path.
/// \return true if absolute path, false if relative path.
bool IsAbsolutePath(const std::string& path);
/// Join path segments into a longer path
/// \param segments The path segments.
/// \return the path formed by joining the segments.
std::string JoinPath(std::initializer_list<std::string> segments);
/// Get the basename of a path.
/// \param path The path.
/// \return the last segment of the path.
std::string BaseName(const std::string& path);
/// Get the dirname of a path.
/// \param path The path.
/// \return all but the last segment of the path.
std::string DirName(const std::string& path);
/// Does a file or directory exist?
/// \param path The path to check for existence.
/// \param exists Returns true if file/dir exists
/// \return Error status if unable to perform the check
Status FileExists(const std::string& path, bool* exists);
/// Is a path a directory?
/// \param path The path to check.
/// \param is_dir Returns true if path represents a directory
/// \return Error status
Status IsDirectory(const std::string& path, bool* is_dir);
/// Get file modification time in nanoseconds.
/// A file is considered modified in Triton when its binary content has changed
/// including the action of replacing it with another file.
/// \param path The path.
/// \param mtime_ns Returns the file modification time. For some filesystems a
/// file/folder may not have a modification time; in that case 0 is returned.
/// \return Error status
Status FileModificationTime(const std::string& path, int64_t* mtime_ns);
/// Get the contents of a directory.
/// \param path The directory path.
/// \param contents Returns the directory contents.
/// \return Error status
Status GetDirectoryContents(
const std::string& path, std::set<std::string>* contents);
/// Get the sub-directories of a path.
/// \param path The path.
/// \param subdirs Returns the names of the sub-directories.
/// \return Error status
Status GetDirectorySubdirs(
const std::string& path, std::set<std::string>* subdirs);
/// Get the files contained in a directory.
/// \param path The directory.
/// \param skip_hidden_files Ignores the hidden files in the directory.
/// \param files Returns the names of the files.
/// \return Error status
Status GetDirectoryFiles(
const std::string& path, const bool skip_hidden_files,
std::set<std::string>* files);
/// Read a text file into a string.
/// \param path The path of the file.
/// \param contents Returns the contents of the file.
/// \return Error status
Status ReadTextFile(const std::string& path, std::string* contents);
/// Create an object representing a local copy of a path.
/// \param path The path of the directory or file.
/// \param localized Returns the LocalizedPath object
/// representing the local copy of the path.
/// \return Error status
Status LocalizePath(
const std::string& path, std::shared_ptr<LocalizedPath>* localized);
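/// Illustrative usage sketch (the S3 path is made up):
///
///   std::shared_ptr<LocalizedPath> localized;
///   Status status = LocalizePath("s3://bucket/models/resnet", &localized);
///   // For remote paths, Path() is the temporary local copy; for local
///   // paths it is the original path.
///   const std::string& local_path = localized->Path();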
/// Write a string to a file.
/// \param path The path of the file.
/// \param contents The contents to write to the file.
/// \return Error status
Status WriteTextFile(const std::string& path, const std::string& contents);
/// Write binary to a file.
/// \param path The path of the file.
/// \param contents The contents to write to the file.
/// \param content_len The size of the content.
/// \return Error status
Status WriteBinaryFile(
const std::string& path, const char* contents, const size_t content_len);
/// Read a prototext file.
/// \param path The path of the file.
/// \param msg Returns the protobuf message for the file.
/// \return Error status
Status ReadTextProto(const std::string& path, google::protobuf::Message* msg);
/// Write a prototext file.
/// \param path The path of the file.
/// \param msg The protobuf to write.
/// \return Error status
Status WriteTextProto(
const std::string& path, const google::protobuf::Message& msg);
/// Read a binary protobuf file.
/// \param path The path of the file.
/// \param msg Returns the protobuf message for the file.
/// \return Error status
Status ReadBinaryProto(
const std::string& path, google::protobuf::MessageLite* msg);
/// Create a directory of the specified path.
/// \param dir The path to the directory.
/// \param recursive Whether the parent directories will be created
/// if they do not exist.
/// \return Error status if the directory can't be created
Status MakeDirectory(const std::string& dir, const bool recursive);
/// Create a temporary directory of the specified filesystem type.
/// \param type The type of the filesystem.
/// \param temp_dir Returns the path to the temporary directory.
/// \return Error status
Status MakeTemporaryDirectory(const FileSystemType type, std::string* temp_dir);
/// Delete a path.
/// \param path The path to the directory or file.
/// \return Error status
Status DeletePath(const std::string& path);
/// Infer the filesystem type from the given path.
/// \param path The path to infer the filesystem type from.
/// \param type Returns the filesystem type of the path.
/// \return Error status
Status GetFileSystemType(const std::string& path, FileSystemType* type);
/// Return the string representation of the filesystem type.
/// \param type The filesystem type.
/// \return The string representation of the type.
const std::string& FileSystemTypeString(const FileSystemType type);
}} // namespace triton::core
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_parameter.h"
namespace triton { namespace core {
const void*
InferenceParameter::ValuePointer() const
{
switch (type_) {
case TRITONSERVER_PARAMETER_STRING:
return reinterpret_cast<const void*>(value_string_.c_str());
case TRITONSERVER_PARAMETER_INT:
return reinterpret_cast<const void*>(&value_int64_);
case TRITONSERVER_PARAMETER_BOOL:
return reinterpret_cast<const void*>(&value_bool_);
case TRITONSERVER_PARAMETER_BYTES:
return reinterpret_cast<const void*>(value_bytes_);
default:
break;
}
return nullptr;
}
std::ostream&
operator<<(std::ostream& out, const InferenceParameter& parameter)
{
out << "[0x" << std::addressof(parameter) << "] "
<< "name: " << parameter.Name()
<< ", type: " << TRITONSERVER_ParameterTypeString(parameter.Type())
<< ", value: ";
return out;
}
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <iostream>
#include <string>
#include "tritonserver_apis.h"
namespace triton { namespace core {
//
// An inference parameter.
//
class InferenceParameter {
public:
InferenceParameter(const char* name, const char* value)
: name_(name), type_(TRITONSERVER_PARAMETER_STRING), value_string_(value)
{
byte_size_ = value_string_.size();
}
InferenceParameter(const char* name, const int64_t value)
: name_(name), type_(TRITONSERVER_PARAMETER_INT), value_int64_(value),
byte_size_(sizeof(int64_t))
{
}
InferenceParameter(const char* name, const bool value)
: name_(name), type_(TRITONSERVER_PARAMETER_BOOL), value_bool_(value),
byte_size_(sizeof(bool))
{
}
InferenceParameter(const char* name, const void* ptr, const uint64_t size)
: name_(name), type_(TRITONSERVER_PARAMETER_BYTES), value_bytes_(ptr),
byte_size_(size)
{
}
// The name of the parameter.
const std::string& Name() const { return name_; }
// Data type of the parameter.
TRITONSERVER_ParameterType Type() const { return type_; }
// Return a pointer to the parameter, or a pointer to the data content
// if type_ is TRITONSERVER_PARAMETER_BYTES. This returned pointer must be
// cast correctly based on 'type_'.
// TRITONSERVER_PARAMETER_STRING -> const char*
// TRITONSERVER_PARAMETER_INT -> int64_t*
// TRITONSERVER_PARAMETER_BOOL -> bool*
// TRITONSERVER_PARAMETER_BYTES -> const void*
const void* ValuePointer() const;
// Return the data byte size of the parameter.
uint64_t ValueByteSize() const { return byte_size_; }
// Return the parameter value string; the return value is valid only if
// Type() returns TRITONSERVER_PARAMETER_STRING.
const std::string& ValueString() const { return value_string_; }
private:
friend std::ostream& operator<<(
std::ostream& out, const InferenceParameter& parameter);
std::string name_;
TRITONSERVER_ParameterType type_;
std::string value_string_;
int64_t value_int64_;
bool value_bool_;
const void* value_bytes_;
uint64_t byte_size_;
};
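// Illustrative usage sketch of ValuePointer(), cast according to Type()
// (the parameter name and value are made up):
//
//   InferenceParameter p("priority", static_cast<int64_t>(5));
//   if (p.Type() == TRITONSERVER_PARAMETER_INT) {
//     const int64_t v = *reinterpret_cast<const int64_t*>(p.ValuePointer());
//   }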
std::ostream& operator<<(
std::ostream& out, const InferenceParameter& parameter);
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_request.h"
#include <algorithm>
#include <deque>
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
#ifdef TRITON_ENABLE_TRACING
#include "cuda_utils.h"
#endif // TRITON_ENABLE_TRACING
namespace triton { namespace core {
namespace {
// Utilities for Null request feature.
TRITONSERVER_Error*
NullResponseAlloc(
TRITONSERVER_ResponseAllocator* allocator, const char* tensor_name,
size_t byte_size, TRITONSERVER_MemoryType preferred_memory_type,
int64_t preferred_memory_type_id, void* userp, void** buffer,
void** buffer_userp, TRITONSERVER_MemoryType* actual_memory_type,
int64_t* actual_memory_type_id)
{
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"unexpected allocation for null request, no output should be requested.");
}
TRITONSERVER_Error*
NullResponseRelease(
TRITONSERVER_ResponseAllocator* allocator, void* buffer, void* buffer_userp,
size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
return TRITONSERVER_ErrorNew(
TRITONSERVER_ERROR_INTERNAL,
"unexpected release for null request, no output should be requested.");
}
ResponseAllocator null_allocator = ResponseAllocator(
NullResponseAlloc, NullResponseRelease, nullptr /* start_fn */);
void
NullResponseComplete(
TRITONSERVER_InferenceResponse* iresponse, const uint32_t flags,
void* userp)
{
if (iresponse != nullptr) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceResponseDelete(iresponse),
"deleting null response");
}
}
void
NullRequestComplete(
TRITONSERVER_InferenceRequest* request, const uint32_t flags, void* userp)
{
if ((flags & TRITONSERVER_REQUEST_RELEASE_ALL) != 0) {
LOG_TRITONSERVER_ERROR(
TRITONSERVER_InferenceRequestDelete(request), "deleting null request");
}
}
} // namespace
InferenceRequest::InferenceRequest(
const std::shared_ptr<Model>& model, const int64_t requested_model_version)
: InferenceRequest(model.get(), requested_model_version)
{
model_shared_ = model;
}
InferenceRequest::InferenceRequest(
Model* model, const int64_t requested_model_version)
: needs_normalization_(true), model_raw_(model),
requested_model_version_(requested_model_version), flags_(0),
correlation_id_(0), batch_size_(0), timeout_us_(0), collect_stats_(true)
{
SetPriority(0);
}
const std::string&
InferenceRequest::ModelName() const
{
return model_raw_->Name();
}
int64_t
InferenceRequest::ActualModelVersion() const
{
return model_raw_->Version();
}
void
InferenceRequest::SetPriority(uint32_t p)
{
if ((p == 0) || (p > model_raw_->MaxPriorityLevel())) {
priority_ = model_raw_->DefaultPriorityLevel();
} else {
priority_ = p;
}
}
#ifdef TRITON_ENABLE_TRACING
Status
InferenceRequest::TraceInputTensors(
TRITONSERVER_InferenceTraceActivity activity, const std::string& msg)
{
const auto& inputs = this->ImmutableInputs();
TRITONSERVER_MemoryType dst_memory_type = TRITONSERVER_MEMORY_CPU;
int64_t dst_memory_type_id = 0;
for (const auto& pr : inputs) {
InferenceRequest::Input* ti = pr.second;
// input data
const std::string& name = ti->Name();
TRITONSERVER_DataType datatype = DataTypeToTriton(ti->DType());
uint64_t byte_size = ti->Data()->TotalByteSize();
const int64_t* shape = ti->ShapeWithBatchDim().data();
uint32_t dim_count = ti->ShapeWithBatchDim().size();
uint32_t buffer_count = ti->DataBufferCount();
// chunk buffer
Status status;
const void* buffer;
uint64_t buffer_size;
TRITONSERVER_MemoryType src_memory_type;
int64_t src_memory_type_id;
bool cuda_used;
if (buffer_count == 0) {
LOG_STATUS_ERROR(
status, LogRequest() +
TRITONSERVER_InferenceTraceActivityString(activity) +
": " + msg + ": tensor: " + name + ": no buffer chunk");
continue;
}
if (buffer_count == 1) {
status = ti->DataBuffer(
0, &buffer, &buffer_size, &src_memory_type, &src_memory_type_id);
if (!status.IsOk()) {
LOG_STATUS_ERROR(
status, LogRequest() +
TRITONSERVER_InferenceTraceActivityString(activity) +
": " + msg + ": tensor: " + name +
": fail to get data buffer: " + status.Message());
return status;
}
if (buffer_size != byte_size) {
LOG_STATUS_ERROR(
status,
LogRequest() + TRITONSERVER_InferenceTraceActivityString(activity) +
": " + msg + ": tensor: " + name + ": truncated buffer");
continue;
}
INFER_TRACE_TENSOR_ACTIVITY(
this->trace_, activity, name.c_str(), datatype,
const_cast<void*>(buffer), buffer_size, shape, dim_count,
src_memory_type, src_memory_type_id);
continue;
}
// input buffer
std::vector<char> in_buffer(byte_size);
char* base = in_buffer.data();
size_t offset = 0;
for (uint32_t b = 0; b < buffer_count; ++b) {
status = ti->DataBuffer(
b, &buffer, &buffer_size, &src_memory_type, &src_memory_type_id);
if (!status.IsOk()) {
LOG_STATUS_ERROR(
status, LogRequest() +
TRITONSERVER_InferenceTraceActivityString(activity) +
": " + msg + ": tensor: " + name +
": fail to get data buffer: " + status.Message());
return status;
}
status = CopyBuffer(
"InferenceRequest TraceInputTensors", src_memory_type,
src_memory_type_id, dst_memory_type, dst_memory_type_id, buffer_size,
buffer, base + offset, nullptr, &cuda_used);
if (!status.IsOk()) {
LOG_STATUS_ERROR(
status, LogRequest() +
TRITONSERVER_InferenceTraceActivityString(activity) +
": " + msg + ": tensor: " + name +
": fail to copy buffer: " + status.Message());
return status;
}
offset += buffer_size;
}
INFER_TRACE_TENSOR_ACTIVITY(
this->trace_, activity, name.c_str(), datatype,
static_cast<void*>(base), byte_size, shape, dim_count, dst_memory_type,
dst_memory_type_id);
}
return Status::Success;
}
#endif // TRITON_ENABLE_TRACING
Status
InferenceRequest::OutputBufferProperties(
const char* name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id)
{
const auto allocator = response_factory_->Allocator();
if ((allocator == nullptr) || (allocator->QueryFn() == nullptr)) {
return Status(
Status::Code::UNAVAILABLE,
(LogRequest() + "Output properties are not available").c_str());
} else {
RETURN_IF_TRITONSERVER_ERROR(allocator->QueryFn()(
reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
const_cast<ResponseAllocator*>(allocator)),
response_factory_->AllocatorUserp(), name, byte_size, memory_type,
memory_type_id));
}
return Status::Success;
}
Status
InferenceRequest::Run(std::unique_ptr<InferenceRequest>& request)
{
return request->model_raw_->Enqueue(request);
}
void
InferenceRequest::RespondIfError(
std::unique_ptr<InferenceRequest>& request, const Status& status,
const bool release_request)
{
if (status.IsOk()) {
return;
}
// Use the response factory to create a response, set the status,
// and send it. If something goes wrong all we can do is log the
// error. Because this is sending an error we assume that this is
// the last response for the request and so set the FINAL flag.
std::unique_ptr<InferenceResponse> response;
LOG_STATUS_ERROR(
request->response_factory_->CreateResponse(&response),
(request->LogRequest() + "failed to create error response").c_str());
LOG_STATUS_ERROR(
InferenceResponse::SendWithStatus(
std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL, status),
(request->LogRequest() + "failed to send error response").c_str());
// If releasing the request then invoke the release callback which
// gives ownership to the callback. So can't access 'request' after
// this point.
if (release_request) {
InferenceRequest::Release(
std::move(request), TRITONSERVER_REQUEST_RELEASE_ALL);
}
}
void
InferenceRequest::RespondIfError(
std::vector<std::unique_ptr<InferenceRequest>>& requests,
const Status& status, const bool release_requests)
{
if (status.IsOk()) {
return;
}
for (auto& request : requests) {
RespondIfError(request, status, release_requests);
}
}
void
InferenceRequest::Release(
std::unique_ptr<InferenceRequest>&& request, const uint32_t release_flags)
{
// Invoke the release callbacks added internally before releasing the
// request to user provided callback.
for (auto it = request->release_callbacks_.rbegin();
it != request->release_callbacks_.rend(); it++) {
(*it)();
}
request->release_callbacks_.clear();
#ifdef TRITON_ENABLE_TRACING
// If tracing then record request end and release the trace.
// This must be before the request callback to ensure the trace
// is properly layered, as the request may be nested in an ensemble
// and the callback may interact with upper level trace.
if (request->trace_ != nullptr) {
request->trace_->ReportNow(TRITONSERVER_TRACE_REQUEST_END);
request->ReleaseTrace();
}
#endif // TRITON_ENABLE_TRACING
void* userp = request->release_userp_;
auto& release_fn = request->release_fn_;
release_fn(
reinterpret_cast<TRITONSERVER_InferenceRequest*>(request.release()),
release_flags, userp);
}
InferenceRequest*
InferenceRequest::CopyAsNull(const InferenceRequest& from)
{
// Create a copy of the 'from' request with artificial inputs and no requested
// outputs. It may be more efficient to share inputs and other metadata,
// but that would bind the null request to the 'from' request's lifecycle.
std::unique_ptr<InferenceRequest> lrequest(
new InferenceRequest(from.model_raw_, from.requested_model_version_));
lrequest->needs_normalization_ = false;
lrequest->batch_size_ = from.batch_size_;
lrequest->collect_stats_ = false;
// Three passes: first to construct the inputs for the shape-tensor inputs;
// second to obtain the max input byte size for allocating a buffer large
// enough for all non-shape-tensor inputs; third to construct the inputs for
// those tensors.
// First pass
for (const auto& input : from.OriginalInputs()) {
// Handle only shape tensors in this pass
if (!input.second.IsShapeTensor()) {
continue;
}
// Prepare the memory to hold input data
size_t byte_size = input.second.Data()->TotalByteSize();
auto mem_type = TRITONSERVER_MEMORY_CPU;
int64_t mem_id = 0;
std::shared_ptr<MutableMemory> data =
std::make_shared<AllocatedMemory>(byte_size, mem_type, mem_id);
// Get the source buffer. Assumes shape tensors are in a single buffer on the
// CPU.
const auto& from_data = input.second.Data();
size_t from_data_byte_size;
TRITONSERVER_MemoryType from_data_memory_type;
int64_t from_data_memory_id;
const char* from_data_buffer = from_data->BufferAt(
0 /* idx */, &from_data_byte_size, &from_data_memory_type,
&from_data_memory_id);
if (from_data_byte_size != byte_size) {
LOG_WARNING
<< lrequest->LogRequest()
<< "The byte size of shape tensor to be copied does not match";
}
// Copy the shape values to the input buffer
std::memcpy(data->MutableBuffer(), from_data_buffer, from_data_byte_size);
Input* new_input;
lrequest->AddOriginalInput(
input.first, input.second.DType(), input.second.Shape(), &new_input);
// Must normalize shape here...
*new_input->MutableShape() = input.second.Shape();
*new_input->MutableShapeWithBatchDim() = input.second.ShapeWithBatchDim();
new_input->SetData(data);
}
// Second pass
size_t max_byte_size = 0;
size_t max_str_byte_size = 0;
const std::string* max_input_name;
for (const auto& input : from.OriginalInputs()) {
// Skip shape tensors in this pass
if (input.second.IsShapeTensor()) {
continue;
}
if (input.second.DType() == inference::DataType::TYPE_STRING) {
int64_t element_count =
triton::common::GetElementCount(input.second.Shape());
size_t str_byte_size = static_cast<size_t>(4 * element_count);
max_str_byte_size = std::max(str_byte_size, max_str_byte_size);
if (str_byte_size > max_byte_size) {
max_byte_size = str_byte_size;
max_input_name = &(input.first);
}
} else {
if (input.second.Data()->TotalByteSize() >= max_byte_size) {
max_byte_size = input.second.Data()->TotalByteSize();
max_input_name = &(input.first);
}
}
}
// Third pass
// [DLIS-1268] should use one growable static buffer for all null requests
auto mem_type = TRITONSERVER_MEMORY_CPU;
int64_t mem_id = 0;
std::shared_ptr<MutableMemory> data =
std::make_shared<AllocatedMemory>(max_byte_size, mem_type, mem_id);
auto data_base = data->BufferAt(0, &max_byte_size, &mem_type, &mem_id);
// Zero initialization is only required when there is a TYPE_BYTES tensor in
// the request. Only set the required number of bytes to zero.
if (max_str_byte_size > 0) {
std::fill(
data->MutableBuffer(), data->MutableBuffer() + max_str_byte_size, 0);
}
for (const auto& input : from.OriginalInputs()) {
// skip shape tensors in this pass
if (input.second.IsShapeTensor()) {
continue;
}
Input* new_input;
lrequest->AddOriginalInput(
input.first, input.second.DType(), input.second.Shape(), &new_input);
// Must normalize shape here...
*new_input->MutableShape() = input.second.Shape();
*new_input->MutableShapeWithBatchDim() = input.second.ShapeWithBatchDim();
// Note that the input that has the max byte size will be responsible for
// holding the artificial data, while the other inputs will hold a reference
// to it with a byte size that matches 'from'.
if (input.first == *max_input_name) {
new_input->SetData(data);
} else {
if (inference::DataType::TYPE_STRING == input.second.DType()) {
new_input->AppendData(
data_base,
triton::common::GetElementCount(input.second.Shape()) * 4, mem_type,
mem_id);
} else {
new_input->AppendData(
data_base, input.second.Data()->TotalByteSize(), mem_type, mem_id);
}
}
}
// No outputs were requested and thus there should be no allocations.
lrequest->SetResponseCallback(
&null_allocator, nullptr, NullResponseComplete, nullptr);
lrequest->SetReleaseCallback(NullRequestComplete, nullptr);
// Must normalize inputs here...
for (auto& pr : lrequest->original_inputs_) {
lrequest->inputs_.emplace(
std::make_pair(pr.second.Name(), std::addressof(pr.second)));
}
return lrequest.release();
}
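// Illustrative usage sketch (not part of the original source; 'seed_request'
// is a hypothetical name): the direct sequence batcher can pad an idle batch
// slot with a null request derived from an in-flight request, e.g.
//
//   std::unique_ptr<InferenceRequest> null_request(
//       InferenceRequest::CopyAsNull(*seed_request));
//
// The copy carries the same input names, datatypes and shapes as
// 'seed_request' but holds artificial data, requests no outputs, and does
// not collect statistics.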
Status
InferenceRequest::MutableOriginalInput(
const std::string& name, InferenceRequest::Input** input)
{
auto itr = original_inputs_.find(name);
if (itr == original_inputs_.end()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + name + "' does not exist in request");
}
*input = &(itr->second);
return Status::Success;
}
Status
InferenceRequest::ImmutableInput(
const std::string& name, const InferenceRequest::Input** input) const
{
auto itr = inputs_.find(name);
if (itr == inputs_.end()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + name + "' does not exist in request");
}
*input = itr->second;
return Status::Success;
}
Status
InferenceRequest::AddOriginalInput(
const std::string& name, const inference::DataType datatype,
const int64_t* shape, const uint64_t dim_count,
InferenceRequest::Input** input)
{
const auto& pr = original_inputs_.emplace(
std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple(name, datatype, shape, dim_count));
if (!pr.second) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + name + "' already exists in request");
}
if (input != nullptr) {
*input = std::addressof(pr.first->second);
}
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::AddOriginalInput(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape, InferenceRequest::Input** input)
{
return AddOriginalInput(name, datatype, &shape[0], shape.size(), input);
}
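// Hedged example of populating a request with these helpers ('req' and 'buf'
// are hypothetical; a 4x4 FP32 tensor occupies 4 * 4 * 4 = 64 bytes):
//
//   InferenceRequest::Input* in = nullptr;
//   RETURN_IF_ERROR(req->AddOriginalInput(
//       "INPUT0", inference::DataType::TYPE_FP32, {4, 4}, &in));
//   RETURN_IF_ERROR(in->AppendData(
//       buf, 64 /* byte_size */, TRITONSERVER_MEMORY_CPU, 0 /* memory id */));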
Status
InferenceRequest::AddRawInput(
const std::string& name, InferenceRequest::Input** input)
{
if (original_inputs_.size() != 0) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "raw input '" + name +
"' can't be added to request with other inputs");
}
const auto& pr = original_inputs_.emplace(
std::piecewise_construct, std::forward_as_tuple(name),
std::forward_as_tuple());
if (!pr.second) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + name + "' already exists in request");
}
if (input != nullptr) {
*input = std::addressof(pr.first->second);
}
raw_input_name_ = name;
needs_normalization_ = true;
return Status::Success;
}
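// Note: a raw input is added with empty metadata; its name, datatype and
// shape are deduced from the model configuration later, in Normalize(),
// via Input::SetMetadata().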
Status
InferenceRequest::RemoveOriginalInput(const std::string& name)
{
if (original_inputs_.erase(name) != 1) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + name + "' does not exist in request");
}
if (name == raw_input_name_) {
raw_input_name_.clear();
}
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::RemoveAllOriginalInputs()
{
original_inputs_.clear();
raw_input_name_.clear();
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::AddOverrideInput(
const std::string& name, const inference::DataType datatype,
const int64_t batch_size, const std::vector<int64_t>& shape,
std::shared_ptr<InferenceRequest::Input>* input)
{
std::shared_ptr<Input> i = std::make_shared<Input>(name, datatype, shape);
*(i->MutableShape()) = i->OriginalShape();
if (batch_size > 0) {
*(i->MutableShapeWithBatchDim()) = {batch_size};
i->MutableShapeWithBatchDim()->insert(
i->MutableShapeWithBatchDim()->end(), i->OriginalShape().begin(),
i->OriginalShape().end());
} else {
*(i->MutableShapeWithBatchDim()) = i->OriginalShape();
}
RETURN_IF_ERROR(AddOverrideInput(i));
if (input != nullptr) {
*input = std::move(i);
}
return Status::Success;
}
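// Hedged example ('SEQ_START' is a hypothetical tensor name): calling
// AddOverrideInput("SEQ_START", inference::DataType::TYPE_INT32,
// 4 /* batch_size */, {1}) yields an override with Shape() == [1] and
// ShapeWithBatchDim() == [4, 1]; with batch_size 0 both would be [1].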
Status
InferenceRequest::AddOverrideInput(
const std::shared_ptr<InferenceRequest::Input>& input)
{
LOG_VERBOSE(1) << LogRequest() << "adding input override for "
<< input->Name() << ": " << *this;
const auto& pr =
override_inputs_.emplace(std::make_pair(input->Name(), input));
if (!pr.second) {
pr.first->second = input;
}
// Add or replace this override in the inputs...
const auto res = inputs_.emplace(std::make_pair(input->Name(), input.get()));
if (!res.second) {
res.first->second = input.get();
}
LOG_VERBOSE(1) << LogRequest() << "added input override for " << input->Name()
<< ": " << *this;
return Status::Success;
}
Status
InferenceRequest::AddOriginalRequestedOutput(const std::string& name)
{
original_requested_outputs_.insert(name);
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::LoadInputStates()
{
// Add the input states to the inference request.
if (sequence_states_ != nullptr) {
if (sequence_states_->IsNullRequest()) {
sequence_states_ =
SequenceStates::CopyAsNull(sequence_states_->NullSequenceStates());
}
for (auto& input_state_pair : sequence_states_->InputStates()) {
auto& input_state = input_state_pair.second;
std::shared_ptr<InferenceRequest::Input> input =
std::make_shared<InferenceRequest::Input>(
input_state->Name(), input_state->DType(), input_state->Shape());
*input->MutableShapeWithBatchDim() = input_state->Shape();
input->SetData(input_state->Data());
AddOverrideInput(input);
}
}
return Status::Success;
}
Status
InferenceRequest::RemoveOriginalRequestedOutput(const std::string& name)
{
original_requested_outputs_.erase(name);
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::RemoveAllOriginalRequestedOutputs()
{
original_requested_outputs_.clear();
needs_normalization_ = true;
return Status::Success;
}
Status
InferenceRequest::PrepareForInference()
{
// Remove override inputs as those are added during any previous
// inference execution.
inputs_.clear();
override_inputs_.clear();
// Renormalize if anything has changed in the inference request in a
// way that could impact normalization.
if (needs_normalization_) {
RETURN_IF_ERROR(Normalize());
needs_normalization_ = false;
}
// Initially show the actual inputs to be only the original
// inputs. If overrides are added later they will be added to
// 'inputs_'.
for (auto& pr : original_inputs_) {
inputs_.emplace(
std::make_pair(pr.second.Name(), std::addressof(pr.second)));
}
// Clear the timestamps
queue_start_ns_ = 0;
batcher_start_ns_ = 0;
#ifdef TRITON_ENABLE_STATS
request_start_ns_ = 0;
#endif // TRITON_ENABLE_STATS
LOG_VERBOSE(1) << LogRequest() << "prepared: " << *this;
return Status::Success;
}
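// Typical life cycle, inferred from the comments in this file (a sketch, not
// a prescribed contract): construct the request, add original inputs and
// data, set the response and release callbacks, then call
// PrepareForInference() (which normalizes if needed) before Run(). Override
// inputs added during execution are discarded by the next
// PrepareForInference() call.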
Status
InferenceRequest::Normalize()
{
const inference::ModelConfig& model_config = model_raw_->Config();
// Fill metadata for raw input
if (!raw_input_name_.empty()) {
const bool has_multiple_inputs =
(original_inputs_.size() != 1) || (model_config.input_size() != 1);
if (has_multiple_inputs) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "Raw request must only have 1 input (found " +
std::to_string(original_inputs_.size()) +
") to be deduced but got " +
std::to_string(model_config.input_size()) + " inputs in '" +
ModelName() + "' model configuration");
}
auto it = original_inputs_.begin();
if (raw_input_name_ != it->first) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "Unexpected reference name for raw input '" +
raw_input_name_ + "' got '" + it->first + "'");
}
const auto& config_input = model_config.input(0);
auto& raw_input = it->second;
std::vector<int64_t> shape;
if (model_config.max_batch_size() != 0) {
shape.emplace_back(1);
}
int64_t dynamic_axis = -1;
size_t element_cnt = 1;
for (const auto& dim : config_input.dims()) {
if (dim == triton::common::WILDCARD_DIM) {
if (dynamic_axis != -1) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "The shape of the raw input '" +
config_input.name() +
"' can not be deduced because there are more than one "
"variable-sized dimension");
}
dynamic_axis = shape.size();
} else {
element_cnt *= (size_t)dim;
}
shape.emplace_back(dim);
}
if ((config_input.data_type() == inference::DataType::TYPE_STRING)) {
const bool has_one_element = (dynamic_axis == -1) && (element_cnt == 1);
if (!has_one_element) {
return Status(
Status::Code::INVALID_ARG, LogRequest() +
"For BYTE datatype raw input, the "
"model must have input shape [1]");
}
// In the case of the BYTES (TYPE_STRING) data type, we will prepend the byte
// size to follow the Triton string-tensor convention.
raw_input_size_ = raw_input.Data()->TotalByteSize();
RETURN_IF_ERROR(raw_input.PrependData(
&raw_input_size_, sizeof(uint32_t), TRITONSERVER_MEMORY_CPU, 0));
// For simplicity, do not allow the BYTES raw input to have host-policy
// specific data; such a case won't happen given the current protocol spec.
// Input::PrependData() will need to be extended if that changes.
if (!raw_input.HostPolicyData().empty()) {
return Status(
Status::Code::INVALID_ARG, LogRequest() +
"Raw input with data associated "
"with a host policy setting is not "
"currently supported");
}
} else if (dynamic_axis != -1) {
shape[dynamic_axis] =
raw_input.Data()->TotalByteSize() / element_cnt /
triton::common::GetDataTypeByteSize(config_input.data_type());
}
raw_input.SetMetadata(config_input.name(), config_input.data_type(), shape);
}
// Initialize the requested outputs to be used during inference. If
// original_requested_outputs_ is empty assume all outputs specified
// in model config are being requested.
requested_outputs_.clear();
if (original_requested_outputs_.size() == 0) {
for (const auto& output : model_config.output()) {
requested_outputs_.insert(output.name());
}
} else {
// Validate if the original requested output name exists in the
// model configuration.
for (const auto& output_name : original_requested_outputs_) {
const inference::ModelOutput* output_config;
RETURN_IF_ERROR(model_raw_->GetOutput(output_name, &output_config));
}
}
// Make sure that the request is providing the number of inputs
// as is expected by the model.
if ((original_inputs_.size() > (size_t)model_config.input_size()) ||
(original_inputs_.size() < model_raw_->RequiredInputCount())) {
// If no input is marked as optional, then use exact match error message
// for consistency / backward compatibility
if ((size_t)model_config.input_size() == model_raw_->RequiredInputCount()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected " +
std::to_string(model_config.input_size()) + " inputs but got " +
std::to_string(original_inputs_.size()) + " inputs for model '" +
ModelName() + "'");
} else {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "expected number of inputs between " +
std::to_string(model_raw_->RequiredInputCount()) + " and " +
std::to_string(model_config.input_size()) + " but got " +
std::to_string(original_inputs_.size()) + " inputs for model '" +
ModelName() + "'");
}
}
// Determine the batch size and shape of each input.
if (model_config.max_batch_size() == 0) {
// Model does not support Triton-style batching so set as
// batch-size 0 and leave the tensor shapes as they are.
batch_size_ = 0;
for (auto& pr : original_inputs_) {
auto& input = pr.second;
*input.MutableShape() = input.OriginalShape();
}
} else {
// Model does support Triton-style batching so each input tensor
// must have the same first dimension which is the batch
// size. Adjust the shape of the input tensors to remove the batch
// dimension.
batch_size_ = 0;
for (auto& pr : original_inputs_) {
auto& input = pr.second;
// For a shape tensor, keep the tensor's shape as it is and mark
// that the input is a shape tensor.
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(input.Name(), &input_config));
if (input_config->is_shape_tensor()) {
*input.MutableShape() = input.OriginalShape();
input.SetIsShapeTensor(true);
continue;
}
if (input.OriginalShape().size() == 0) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' has no shape but model requires batch dimension for '" +
ModelName() + "'");
}
if (batch_size_ == 0) {
batch_size_ = input.OriginalShape()[0];
} else if (input.OriginalShape()[0] != batch_size_) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "input '" + input.Name() +
"' batch size does not match other inputs for '" + ModelName() +
"'");
}
input.MutableShape()->assign(
input.OriginalShape().begin() + 1, input.OriginalShape().end());
}
}
// Make sure request batch-size doesn't exceed what is supported by
// the model.
if ((int)batch_size_ > model_config.max_batch_size()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference request batch-size must be <= " +
std::to_string(model_config.max_batch_size()) + " for '" +
ModelName() + "'");
}
// Verify that each input shape is valid for the model, make
// adjustments for reshapes and find the total tensor size.
for (auto& pr : original_inputs_) {
const inference::ModelInput* input_config;
RETURN_IF_ERROR(model_raw_->GetInput(pr.second.Name(), &input_config));
auto& input = pr.second;
auto shape = input.MutableShape();
if (input.DType() != input_config->data_type()) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "inference input data-type is '" +
std::string(
triton::common::DataTypeToProtocolString(input.DType())) +
"', model expects '" +
std::string(triton::common::DataTypeToProtocolString(
input_config->data_type())) +
"' for '" + ModelName() + "'");
}
// Validate input shape
{
bool match_config = true;
const auto& config_dims = input_config->dims();
const auto& input_dims = *shape;
if (config_dims.size() != (int64_t)input_dims.size()) {
match_config = false;
} else {
for (int i = 0; i < config_dims.size(); ++i) {
if (input_dims[i] == triton::common::WILDCARD_DIM) {
return Status(
Status::Code::INVALID_ARG,
LogRequest() +
"All input dimensions should be specified for input '" +
pr.first + "' for model '" + ModelName() + "', got " +
triton::common::DimsListToString(input.OriginalShape()));
} else if (
(config_dims[i] != triton::common::WILDCARD_DIM) &&
(config_dims[i] != input_dims[i])) {
match_config = false;
break;
}
}
}
if (!match_config) {
triton::common::DimsList full_dims;
if (model_config.max_batch_size() > 0) {
full_dims.Add(triton::common::WILDCARD_DIM);
}
for (int i = 0; i < input_config->dims_size(); ++i) {
full_dims.Add(input_config->dims(i));
}
return Status(
Status::Code::INVALID_ARG,
LogRequest() + "unexpected shape for input '" + pr.first +
"' for model '" + ModelName() + "'. Expected " +
triton::common::DimsListToString(full_dims) + ", got " +
triton::common::DimsListToString(input.OriginalShape()));
}
}
// If there is a reshape for this input then adjust them to
// match the reshape. As reshape may have variable-size
// dimensions, we need to record corresponding value so that we
// can set the value correctly for reshape.
if (input_config->has_reshape()) {
std::deque<int64_t> variable_size_values;
for (int64_t idx = 0; idx < input_config->dims_size(); idx++) {
if (input_config->dims(idx) == -1) {
variable_size_values.push_back((*shape)[idx]);
}
}
shape->clear();
for (const auto& dim : input_config->reshape().shape()) {
if (dim == -1) {
shape->push_back(variable_size_values.front());
variable_size_values.pop_front();
} else {
shape->push_back(dim);
}
}
}
// Create shape with batch dimension.
// FIXME, should not need this!!
if (batch_size_ == 0) {
*input.MutableShapeWithBatchDim() = *shape;
} else {
input.MutableShapeWithBatchDim()->clear();
input.MutableShapeWithBatchDim()->push_back(batch_size_);
for (int64_t d : *shape) {
input.MutableShapeWithBatchDim()->push_back(d);
}
}
}
return Status::Success;
}
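// Worked example (illustrative, assuming max_batch_size >= 8): for a config
// input with dims [ 3, -1 ] and a request input with original shape
// [ 8, 3, 224 ], Normalize() sets batch_size_ = 8, Shape() = [ 3, 224 ] and
// ShapeWithBatchDim() = [ 8, 3, 224 ]. Shape tensors and reshape
// configurations are handled by the dedicated branches above.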
#ifdef TRITON_ENABLE_STATS
void
InferenceRequest::ReportStatistics(
MetricModelReporter* metric_reporter, bool success,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns)
{
if (!collect_stats_) {
return;
}
#ifdef TRITON_ENABLE_TRACING
if (trace_ != nullptr) {
trace_->Report(TRITONSERVER_TRACE_COMPUTE_START, compute_start_ns);
trace_->Report(TRITONSERVER_TRACE_COMPUTE_INPUT_END, compute_input_end_ns);
trace_->Report(
TRITONSERVER_TRACE_COMPUTE_OUTPUT_START, compute_output_start_ns);
trace_->Report(TRITONSERVER_TRACE_COMPUTE_END, compute_end_ns);
}
#endif // TRITON_ENABLE_TRACING
INFER_STATS_DECL_TIMESTAMP(request_end_ns);
if (success) {
model_raw_->MutableStatsAggregator()->UpdateSuccess(
metric_reporter, std::max(1U, batch_size_), request_start_ns_,
queue_start_ns_, compute_start_ns, compute_input_end_ns,
compute_output_start_ns, compute_end_ns, request_end_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateSuccess(
nullptr /* metric_reporter */, std::max(1U, batch_size_),
request_start_ns_, queue_start_ns_, compute_start_ns,
compute_input_end_ns, compute_output_start_ns, compute_end_ns,
request_end_ns);
}
} else {
model_raw_->MutableStatsAggregator()->UpdateFailure(
metric_reporter, request_start_ns_, request_end_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateFailure(
nullptr /* metric_reporter */, request_start_ns_, request_end_ns);
}
}
}
void
InferenceRequest::ReportStatisticsWithDuration(
MetricModelReporter* metric_reporter, bool success,
const uint64_t compute_start_ns, const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns)
{
if (!collect_stats_) {
return;
}
INFER_STATS_DECL_TIMESTAMP(request_end_ns);
if (success) {
model_raw_->MutableStatsAggregator()->UpdateSuccessWithDuration(
metric_reporter, std::max(1U, batch_size_), request_start_ns_,
queue_start_ns_, compute_start_ns, request_end_ns,
compute_input_duration_ns, compute_infer_duration_ns,
compute_output_duration_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateSuccessWithDuration(
nullptr /* metric_reporter */, std::max(1U, batch_size_),
request_start_ns_, queue_start_ns_, compute_start_ns, request_end_ns,
compute_input_duration_ns, compute_infer_duration_ns,
compute_output_duration_ns);
}
} else {
model_raw_->MutableStatsAggregator()->UpdateFailure(
metric_reporter, request_start_ns_, request_end_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateFailure(
nullptr /* metric_reporter */, request_start_ns_, request_end_ns);
}
}
}
void
InferenceRequest::ReportStatisticsCacheHit(MetricModelReporter* metric_reporter)
{
// Capture end of request time
INFER_STATS_DECL_TIMESTAMP(request_end_ns);
if (cache_lookup_start_ns_ >= cache_lookup_end_ns_) {
LOG_WARNING << LogRequest()
<< "Cache lookup timestamps were not set correctly. Cache "
"lookup duration stats may be incorrect.";
}
const uint64_t cache_lookup_duration_ns =
cache_lookup_end_ns_ - cache_lookup_start_ns_;
// Cache hit is always success
model_raw_->MutableStatsAggregator()->UpdateSuccessCacheHit(
metric_reporter, std::max(1U, batch_size_), request_start_ns_,
queue_start_ns_, cache_lookup_start_ns_, request_end_ns,
cache_lookup_duration_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateSuccessCacheHit(
nullptr /* metric_reporter */, std::max(1U, batch_size_),
request_start_ns_, queue_start_ns_, cache_lookup_start_ns_,
request_end_ns, cache_lookup_duration_ns);
}
}
void
InferenceRequest::ReportStatisticsCacheMiss(
MetricModelReporter* metric_reporter)
{
if (cache_lookup_start_ns_ >= cache_lookup_end_ns_) {
LOG_WARNING << LogRequest()
<< "Cache lookup timestamps were not set correctly. Cache "
"lookup duration stats may be incorrect.";
}
if (cache_insertion_start_ns_ >= cache_insertion_end_ns_) {
LOG_WARNING << LogRequest()
<< "Cache insertion timestamps were not set correctly. Cache "
"insertion duration stats may be incorrect.";
}
const uint64_t cache_lookup_duration_ns =
cache_lookup_end_ns_ - cache_lookup_start_ns_;
const uint64_t cache_insertion_duration_ns =
cache_insertion_end_ns_ - cache_insertion_start_ns_;
model_raw_->MutableStatsAggregator()->UpdateSuccessCacheMiss(
metric_reporter, cache_lookup_duration_ns, cache_insertion_duration_ns);
if (secondary_stats_aggregator_ != nullptr) {
secondary_stats_aggregator_->UpdateSuccessCacheMiss(
nullptr /* metric_reporter */, cache_lookup_duration_ns,
cache_insertion_duration_ns);
}
}
#endif // TRITON_ENABLE_STATS
//
// Input
//
InferenceRequest::Input::Input()
: is_shape_tensor_(false), data_(new MemoryReference),
has_host_policy_specific_data_(false)
{
}
InferenceRequest::Input::Input(
const std::string& name, const inference::DataType datatype,
const int64_t* shape, const uint64_t dim_count)
: name_(name), datatype_(datatype),
original_shape_(shape, shape + dim_count), is_shape_tensor_(false),
data_(new MemoryReference), has_host_policy_specific_data_(false)
{
}
InferenceRequest::Input::Input(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape)
: name_(name), datatype_(datatype), original_shape_(shape),
is_shape_tensor_(false), data_(new MemoryReference),
has_host_policy_specific_data_(false)
{
}
void
InferenceRequest::Input::SetMetadata(
const std::string& name, const inference::DataType& dt,
const std::vector<int64_t>& shape)
{
name_ = name;
datatype_ = dt;
original_shape_ = shape;
}
Status
InferenceRequest::Input::SetIsShapeTensor(const bool is_shape_tensor)
{
is_shape_tensor_ = is_shape_tensor;
return Status::Success;
}
const std::shared_ptr<Memory>&
InferenceRequest::Input::Data(const std::string& host_policy_name) const
{
auto device_data = host_policy_data_map_.find(host_policy_name);
if (device_data == host_policy_data_map_.end()) {
// Fall back on default data if there is no data that has been added for
// this host policy
return data_;
}
return device_data->second;
}
Status
InferenceRequest::Input::AppendData(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
if (byte_size > 0) {
std::static_pointer_cast<MemoryReference>(data_)->AddBuffer(
static_cast<const char*>(base), byte_size, memory_type, memory_type_id);
}
return Status::Success;
}
Status
InferenceRequest::Input::AppendDataWithBufferAttributes(
const void* base, BufferAttributes* buffer_attributes)
{
if (buffer_attributes->ByteSize() > 0) {
std::static_pointer_cast<MemoryReference>(data_)->AddBuffer(
static_cast<const char*>(base), buffer_attributes);
}
return Status::Success;
}
Status
InferenceRequest::Input::AppendDataWithHostPolicy(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id, const char* host_policy_name)
{
auto device_data = host_policy_data_map_.find(host_policy_name);
has_host_policy_specific_data_ = true;
if (device_data == host_policy_data_map_.end()) {
auto insert_pair = host_policy_data_map_.insert(
std::make_pair(std::string(host_policy_name), new MemoryReference));
device_data = insert_pair.first;
}
if (byte_size > 0) {
std::static_pointer_cast<MemoryReference>(device_data->second)
->AddBuffer(
static_cast<const char*>(base), byte_size, memory_type,
memory_type_id);
}
return Status::Success;
}
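// Data appended for a host policy is returned by Data(host_policy_name) and
// DataBufferForHostPolicy(); when no policy-specific buffers exist, both
// fall back to the default 'data_' buffers.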
Status
InferenceRequest::Input::PrependData(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id)
{
if (byte_size > 0) {
std::static_pointer_cast<MemoryReference>(data_)->AddBufferFront(
static_cast<const char*>(base), byte_size, memory_type, memory_type_id);
}
return Status::Success;
}
Status
InferenceRequest::Input::SetData(const std::shared_ptr<Memory>& data)
{
if (data_->TotalByteSize() != 0) {
return Status(
Status::Code::INVALID_ARG,
"input '" + name_ + "' already has data, can't overwrite");
}
data_ = data;
return Status::Success;
}
Status
InferenceRequest::Input::SetData(
const std::string& host_policy_name, const std::shared_ptr<Memory>& data)
{
if (host_policy_data_map_.find(host_policy_name) !=
host_policy_data_map_.end()) {
return Status(
Status::Code::INVALID_ARG, "input '" + name_ +
"' already has data for host policy '" +
host_policy_name + "', can't overwrite");
}
host_policy_data_map_.emplace(host_policy_name, data);
return Status::Success;
}
Status
InferenceRequest::Input::RemoveAllData()
{
data_ = std::make_shared<MemoryReference>();
host_policy_data_map_.clear();
has_host_policy_specific_data_ = false;
return Status::Success;
}
Status
InferenceRequest::Input::DataBuffer(
const size_t idx, const void** base, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const
{
*base = data_->BufferAt(idx, byte_size, memory_type, memory_type_id);
return Status::Success;
}
Status
InferenceRequest::Input::DataBufferAttributes(
const size_t idx, const void** base,
BufferAttributes** buffer_attributes) const
{
*base = data_->BufferAt(idx, buffer_attributes);
return Status::Success;
}
Status
InferenceRequest::Input::DataBufferForHostPolicy(
const size_t idx, const void** base, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
const std::string& host_policy_name) const
{
auto device_data = host_policy_data_map_.find(host_policy_name);
if (device_data == host_policy_data_map_.end()) {
// Return data buffer if there is no host-policy specific buffer available
*base = data_->BufferAt(idx, byte_size, memory_type, memory_type_id);
} else {
*base = device_data->second->BufferAt(
idx, byte_size, memory_type, memory_type_id);
}
return Status::Success;
}
size_t
InferenceRequest::Input::DataBufferCountForHostPolicy(
const std::string& host_policy_name) const
{
auto policy_data = host_policy_data_map_.find(host_policy_name);
if (policy_data != host_policy_data_map_.end()) {
return policy_data->second->BufferCount();
}
return data_->BufferCount();
}
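// Hedged iteration sketch over an input's buffers ('input' is hypothetical):
//
//   for (size_t idx = 0; idx < input->DataBufferCount(); ++idx) {
//     const void* base;
//     size_t byte_size;
//     TRITONSERVER_MemoryType memory_type = TRITONSERVER_MEMORY_CPU;
//     int64_t memory_type_id = 0;
//     RETURN_IF_ERROR(input->DataBuffer(
//         idx, &base, &byte_size, &memory_type, &memory_type_id));
//     // consume 'byte_size' bytes starting at 'base'
//   }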
InferenceRequest::SequenceId::SequenceId()
: sequence_label_(""), sequence_index_(0),
id_type_(InferenceRequest::SequenceId::DataType::UINT64)
{
}
InferenceRequest::SequenceId::SequenceId(const std::string& sequence_label)
: sequence_label_(sequence_label), sequence_index_(0),
id_type_(InferenceRequest::SequenceId::DataType::STRING)
{
}
InferenceRequest::SequenceId::SequenceId(uint64_t sequence_index)
: sequence_label_(""), sequence_index_(sequence_index),
id_type_(InferenceRequest::SequenceId::DataType::UINT64)
{
}
InferenceRequest::SequenceId&
InferenceRequest::SequenceId::operator=(const std::string& rhs)
{
sequence_label_ = rhs;
sequence_index_ = 0;
id_type_ = InferenceRequest::SequenceId::DataType::STRING;
return *this;
}
InferenceRequest::SequenceId&
InferenceRequest::SequenceId::operator=(const uint64_t rhs)
{
sequence_label_ = "";
sequence_index_ = rhs;
id_type_ = InferenceRequest::SequenceId::DataType::UINT64;
return *this;
}
std::ostream&
operator<<(std::ostream& out, const InferenceRequest& request)
{
out << "[0x" << std::addressof(request) << "] "
<< "request id: " << request.Id() << ", model: " << request.ModelName()
<< ", requested version: " << request.RequestedModelVersion()
<< ", actual version: " << request.ActualModelVersion() << ", flags: 0x"
<< std::hex << request.Flags() << std::dec
<< ", correlation id: " << request.CorrelationId()
<< ", batch size: " << request.BatchSize()
<< ", priority: " << request.Priority()
<< ", timeout (us): " << request.TimeoutMicroseconds() << std::endl;
out << "original inputs:" << std::endl;
for (const auto& itr : request.OriginalInputs()) {
out << "[0x" << std::addressof(itr.second) << "] " << itr.second
<< std::endl;
}
out << "override inputs:" << std::endl;
for (const auto& itr : request.OverrideInputs()) {
out << "[0x" << itr.second.get() << "] " << *itr.second << std::endl;
}
out << "inputs:" << std::endl;
for (const auto& itr : request.ImmutableInputs()) {
out << "[0x" << itr.second << "] " << *itr.second << std::endl;
}
out << "original requested outputs:" << std::endl;
for (const auto& name : request.OriginalRequestedOutputs()) {
out << name << std::endl;
}
out << "requested outputs:" << std::endl;
for (const auto& name : request.ImmutableRequestedOutputs()) {
out << name << std::endl;
}
return out;
}
std::ostream&
operator<<(std::ostream& out, const InferenceRequest::Input& input)
{
out << "input: " << input.Name()
<< ", type: " << triton::common::DataTypeToProtocolString(input.DType())
<< ", original shape: "
<< triton::common::DimsListToString(input.OriginalShape())
<< ", batch + shape: "
<< triton::common::DimsListToString(input.ShapeWithBatchDim())
<< ", shape: " << triton::common::DimsListToString(input.Shape());
if (input.IsShapeTensor()) {
out << ", is_shape_tensor: True";
}
return out;
}
std::ostream&
operator<<(std::ostream& out, const InferenceRequest::SequenceId& sequence_id)
{
switch (sequence_id.Type()) {
case InferenceRequest::SequenceId::DataType::STRING:
out << sequence_id.StringValue();
break;
case InferenceRequest::SequenceId::DataType::UINT64:
out << sequence_id.UnsignedIntValue();
break;
default:
out << sequence_id.UnsignedIntValue();
break;
}
return out;
}
bool
operator==(
const InferenceRequest::SequenceId lhs,
const InferenceRequest::SequenceId rhs)
{
if (lhs.Type() == rhs.Type()) {
switch (lhs.Type()) {
case InferenceRequest::SequenceId::DataType::STRING:
return lhs.StringValue() == rhs.StringValue();
case InferenceRequest::SequenceId::DataType::UINT64:
return lhs.UnsignedIntValue() == rhs.UnsignedIntValue();
default:
return lhs.UnsignedIntValue() == rhs.UnsignedIntValue();
}
} else {
return false;
}
}
bool
operator!=(
const InferenceRequest::SequenceId lhs,
const InferenceRequest::SequenceId rhs)
{
return !(lhs == rhs);
}
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>
#include "buffer_attributes.h"
#include "infer_response.h"
#include "infer_stats.h"
#include "infer_trace.h"
#include "memory.h"
#include "response_allocator.h"
#include "sequence_state.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
class Model;
class InferenceServer;
class MetricModelReporter;
//
// An inference request. A request can be used multiple times for
// inference but before each inference run, PrepareForInference() must
// be called to verify and prepare the request. Verification involves
// ensuring that any changes made since the last inference are
// valid. Preparing involves removing/resetting any state left over
// from the previous inference.
//
class InferenceRequest {
public:
// Input tensor
class Input {
public:
Input();
Input(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape);
Input(
const std::string& name, const inference::DataType datatype,
const int64_t* shape, const uint64_t dim_count);
// Set the name, data type and original shape of the input tensor.
void SetMetadata(
const std::string& name, const inference::DataType& dt,
const std::vector<int64_t>& shape);
// The name of the input tensor. There is no mutable operator for
// the name because it is used in an InferenceRequest map and a
// mutable method would allow it to get out-of-sync.
const std::string& Name() const { return name_; }
// Data type of the input tensor.
inference::DataType DType() const { return datatype_; }
// The original shape of the input tensor.
const std::vector<int64_t>& OriginalShape() const
{
return original_shape_;
}
// The shape of the input tensor after normalization. This shape
// is the original shape modified as required/expected by
// inference processing.
const std::vector<int64_t>& Shape() const { return shape_; }
std::vector<int64_t>* MutableShape() { return &shape_; }
// FIXME. Should not need these functions. All shapes kept here
// should include the batch dimension instead of breaking the same
// into batch + shape.
const std::vector<int64_t>& ShapeWithBatchDim() const
{
return shape_with_batch_dim_;
}
std::vector<int64_t>* MutableShapeWithBatchDim()
{
return &shape_with_batch_dim_;
}
// Return true if host-specific data was added for this input
bool HasHostPolicySpecificData() const
{
return has_host_policy_specific_data_;
}
// Whether or not the input is a TensorRT shape tensor
bool IsShapeTensor() const { return is_shape_tensor_; }
// Set the input to be treated as a shape tensor.
Status SetIsShapeTensor(const bool is_shape_tensor);
// The data for this input.
const std::shared_ptr<Memory>& Data() const { return data_; }
// The data for this input for a specific device
const std::shared_ptr<Memory>& Data(
const std::string& host_policy_name) const;
// Return all host policy data set for this input
const std::map<std::string, std::shared_ptr<Memory>>& HostPolicyData() const
{
return host_policy_data_map_;
}
// Set the data for this input. Error if input already has some
// data.
Status SetData(const std::shared_ptr<Memory>& data);
// Set the data associated with the host policy for this input.
// Return error if input already has some data.
Status SetData(
const std::string& host_policy_name,
const std::shared_ptr<Memory>& data);
// Append a new buffer of data to this input.
Status AppendData(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
Status AppendDataWithHostPolicy(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id, const char* host_policy_name);
Status AppendDataWithBufferAttributes(
const void* base, BufferAttributes* buffer_attributes);
// Prepend a new buffer of data to this input.
Status PrependData(
const void* base, size_t byte_size, TRITONSERVER_MemoryType memory_type,
int64_t memory_type_id);
// Remove all existing data for the input.
Status RemoveAllData();
// Get the number of buffers containing the input tensor data.
size_t DataBufferCount() const { return data_->BufferCount(); }
// Get the number of buffers containing the input tensor data with
// host policy. If there are no buffers corresponding to the specific
// host policy, the number of buffers in the fallback input data is
// returned.
size_t DataBufferCountForHostPolicy(
const std::string& host_policy_name) const;
// Get the 'idx' buffer containing a contiguous chunk of bytes for
// the input. Return error if 'idx' refers to a buffer that does
// not exist. Return a pointer to the chunk in 'base' and the
// size of the chunk in 'byte_size'. 'memory_type' acts as
// both input and output. On input 'memory_type' is the buffer
// memory type preferred by the function caller. On return
// 'memory_type' gives the actual memory type of the chunk pointed
// to by 'base'. 'memory_type_id' acts as both input and
// output. On input 'memory_type_id' is the buffer memory type id
// preferred by the function caller. On return 'memory_type_id'
// gives the actual memory type id of the chunk pointed to by
// 'base'.
Status DataBuffer(
const size_t idx, const void** base, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id) const;
// Get the buffer attributes associated with 'idx' buffer.
Status DataBufferAttributes(
const size_t idx, const void** base,
BufferAttributes** buffer_attributes) const;
// Get the 'idx' buffer containing a contiguous chunk of bytes for
// the input. Return error if 'idx' refers to a buffer that does
// not exist. Return a pointer to the chunk in 'base' and the
// size of the chunk in 'byte_size'. 'memory_type' acts as
// both input and output. On input 'memory_type' is the buffer
// memory type preferred by the function caller. On return
// 'memory_type' gives the actual memory type of the chunk pointed
// to by 'base'. 'memory_type_id' acts as both input and
// output. On input 'memory_type_id' is the buffer memory type id
// preferred by the function caller. On return 'memory_type_id'
// gives the actual memory type id of the chunk pointed to by
// 'base'.
Status DataBufferForHostPolicy(
const size_t idx, const void** base, size_t* byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
const std::string& host_policy_name) const;
private:
DISALLOW_COPY_AND_ASSIGN(Input);
friend std::ostream& operator<<(
std::ostream& out, const InferenceRequest::Input& input);
std::string name_;
inference::DataType datatype_;
std::vector<int64_t> original_shape_;
std::vector<int64_t> shape_;
std::vector<int64_t> shape_with_batch_dim_;
bool is_shape_tensor_;
std::shared_ptr<Memory> data_;
bool has_host_policy_specific_data_;
// A map of host policy to input data memory
std::map<std::string, std::shared_ptr<Memory>> host_policy_data_map_;
};
// Sequence ID can be either a 64 bit integer or a string.
// This class implements the SequenceId type
class SequenceId {
public:
enum class DataType { UINT64, STRING };
SequenceId();
SequenceId(const std::string& sequence_label);
SequenceId(uint64_t sequence_index);
SequenceId& operator=(const SequenceId& rhs) = default;
SequenceId& operator=(const std::string& rhs);
SequenceId& operator=(const uint64_t rhs);
// Functions that help determine exact type of sequence Id
DataType Type() const { return id_type_; }
bool InSequence() const
{
return ((sequence_label_ != "") || (sequence_index_ != 0));
}
// Get the value of the SequenceId based on the type
const std::string& StringValue() const { return sequence_label_; }
uint64_t UnsignedIntValue() const { return sequence_index_; }
private:
friend std::ostream& operator<<(
std::ostream& out, const InferenceRequest::SequenceId& correlation_id);
friend bool operator==(const SequenceId lhs, const SequenceId rhs);
friend bool operator!=(const SequenceId lhs, const SequenceId rhs);
std::string sequence_label_;
uint64_t sequence_index_;
DataType id_type_;
};
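// Usage sketch (illustrative only): a SequenceId may hold either
// representation, and IDs of different types never compare equal.
//
//   SequenceId a(42);          // UINT64
//   SequenceId b("stream-0");  // STRING
//   bool same = (a == b);      // false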
// InferenceRequest
//
// The two constructors are identical except one takes model as a
// shared pointer and the other as a raw pointer. The shared pointer
// version is the primary one and acts to keep the model alive as
// long as the request is in flight. The raw pointer version is used
// only for cases where the model itself is issuing a request
// (e.g. warmup) and no shared pointer version of the model exists
// (because we aren't using shared_from_this).
InferenceRequest(
const std::shared_ptr<Model>& model,
const int64_t requested_model_version);
InferenceRequest(Model* model, const int64_t requested_model_version);
const std::string& ModelName() const;
int64_t RequestedModelVersion() const { return requested_model_version_; }
int64_t ActualModelVersion() const;
const std::string& Id() const { return id_; }
void SetId(const std::string& i) { id_ = i; }
// Return string for logging request ID
std::string LogRequest() const
{
std::string id = Id();
if (id.empty()) {
id = "<id_unknown>";
}
return std::string("[request id: ") + id + "] ";
}
// Flags for the request, union of TRITONSERVER_RequestFlag.
uint32_t Flags() const { return flags_; }
void SetFlags(uint32_t f) { flags_ = f; }
const SequenceId& CorrelationId() const { return correlation_id_; }
void SetCorrelationId(const SequenceId& c) { correlation_id_ = c; }
// The batch size of the request, as understood by Triton. A
// batch-size of 0 indicates that the model doesn't support batching
// in a way that Triton understands. Batch size is not set
// explicitly so there is no setter for it. It is set when the
// request is normalized.
uint32_t BatchSize() const { return batch_size_; }
uint32_t Priority() const { return priority_; }
void SetPriority(uint32_t p);
uint64_t TimeoutMicroseconds() const { return timeout_us_; }
void SetTimeoutMicroseconds(uint64_t t) { timeout_us_ = t; }
uint64_t CacheKey() const { return cache_key_; }
// It is up to the user to update the cache_key_ if modifying any hashable
// fields of the request after cache_key_is_set_ has been set to true.
void SetCacheKey(uint64_t key)
{
cache_key_ = key;
cache_key_is_set_ = true;
}
bool CacheKeyIsSet() const { return cache_key_is_set_; }
#ifdef TRITON_ENABLE_TRACING
const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
std::shared_ptr<InferenceTraceProxy>* MutableTrace() { return &trace_; }
void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
{
trace_ = trace;
response_factory_->SetTrace(trace);
}
void ReleaseTrace()
{
trace_ = nullptr;
response_factory_->ReleaseTrace();
}
Status TraceInputTensors(
TRITONSERVER_InferenceTraceActivity activity, const std::string& msg);
#endif // TRITON_ENABLE_TRACING
// The original inputs are the inputs added to the request before
// the inference execution (that is before
// TRITONSERVER_ServerInferAsync is called). Once execution has
// started the original inputs should not be modified until
// execution completes (and those modifications will apply to the
// next inference execution).
Status MutableOriginalInput(const std::string& name, Input** input);
std::unordered_map<std::string, Input>* MutableOriginalInputs()
{
return &original_inputs_;
}
const std::unordered_map<std::string, Input>& OriginalInputs() const
{
return original_inputs_;
}
// The override inputs are the inputs added to the request after
// inference execution has started (that is after
// TRITONSERVER_ServerInferAsync or equivalent is called). During
// inference processing, if Triton needs to change an original input
// it will add an override instead of changing the original. Triton
// will also use an override if it needs to add a new input to the
// request. Overrides are recorded as shared_ptr so that the same
// override can be used efficiently multiple times or even in
// multiple requests simultaneously. Must be careful not to modify
// an override input if it is being shared unless you want that
// change to be reflected in all requests that hold that override
// input. Override inputs within a specific request are not
// persisted across inference calls.
std::unordered_map<std::string, std::shared_ptr<Input>>*
MutableOverrideInputs()
{
return &override_inputs_;
}
const std::unordered_map<std::string, std::shared_ptr<Input>>&
OverrideInputs() const
{
return override_inputs_;
}
// Get an input taking into account both original inputs and
// overrides. If an override input is available use it, otherwise
// use the original input. Accessing inputs via this method is not
// valid until after PrepareForInference is called.
Status ImmutableInput(const std::string& name, const Input** input) const;
const std::unordered_map<std::string, Input*>& ImmutableInputs() const
{
return inputs_;
}
// The original requested outputs are the requested outputs added to
// the request before the inference execution (that is before
// TRITONSERVER_ServerInferAsync is called). Once execution has
// started the original requested outputs should not be modified
// until execution completes (and those modifications will apply to
// the next inference execution).
const std::set<std::string>& OriginalRequestedOutputs() const
{
return original_requested_outputs_;
}
// Get the requested outputs that should be used during
// inference. Accessing outputs via this method is not valid until
// after PrepareForInference is called.
const std::set<std::string>& ImmutableRequestedOutputs() const
{
return (requested_outputs_.empty()) ? original_requested_outputs_
: requested_outputs_;
}
// Get the response factory.
const std::shared_ptr<InferenceResponseFactory>& ResponseFactory() const
{
return response_factory_;
}
// Add an original input to the request. If 'input' is non-null
// return a pointer to the newly added input.
Status AddOriginalInput(
const std::string& name, const inference::DataType datatype,
const int64_t* shape, const uint64_t dim_count, Input** input = nullptr);
Status AddOriginalInput(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape, Input** input = nullptr);
// Add an original raw input to the request. If 'input' is non-null
// return a pointer to the newly added input.
Status AddRawInput(const std::string& name, Input** input = nullptr);
// Remove a single original input or all inputs.
Status RemoveOriginalInput(const std::string& name);
Status RemoveAllOriginalInputs();
// Add an override input to the request. If 'input' is non-null
// return a pointer to the newly added input.
// FIXME: passing batch size is special handling for the backend API.
// For an override input, 'shape' is given without the batch dimension for
// backends implemented without the backend API (which need a correct
// input.Shape()), while the backend API uses input.ShapeWithBatchDim().
Status AddOverrideInput(
const std::string& name, const inference::DataType datatype,
const int64_t batch_size, const std::vector<int64_t>& shape,
std::shared_ptr<Input>* input = nullptr);
// Add an override input to the request.
Status AddOverrideInput(const std::shared_ptr<Input>& input);
// Request an original requested output.
Status AddOriginalRequestedOutput(const std::string& name);
// Remove a single original requested output or all requested
// outputs.
Status RemoveOriginalRequestedOutput(const std::string& name);
Status RemoveAllOriginalRequestedOutputs();
// Initialize the release callback for the request.
Status SetReleaseCallback(
TRITONSERVER_InferenceRequestReleaseFn_t release_fn, void* release_userp)
{
release_fn_ = release_fn;
release_userp_ = release_userp;
return Status::Success;
}
// Initialize the response factory that is to be used with any
// responses produced for this request.
Status SetResponseCallback(
const ResponseAllocator* allocator, void* alloc_userp,
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp)
{
response_factory_.reset(new InferenceResponseFactory(
model_shared_, id_, allocator, alloc_userp, response_fn, response_userp,
response_delegator_));
return Status::Success;
}
// Returns the preferred memory type and memory type ID of the output buffer
// for the request. 'name' and 'byte_size' are optional and set to nullptr
// if not specified; if provided, they give the allocator more information.
// 'memory_type' and 'memory_type_id' are also used as input to provide types
// preferred by the caller.
// Status::Code::UNAVAILABLE will be returned if output properties are not
// available.
Status OutputBufferProperties(
const char* name, size_t* byte_size, TRITONSERVER_MemoryType* memory_type,
int64_t* memory_type_id);
// Add a callback to be invoked on releasing the request object from Triton.
// Multiple callbacks can be added by calling this function in order,
// and they will be invoked in reverse order.
Status AddInternalReleaseCallback(std::function<void()>&& callback)
{
release_callbacks_.emplace_back(std::move(callback));
return Status::Success;
}
// Add a delegator to be invoked on sending the responses of this request.
// The response will be passed to 'delegator' and 'delegator' must call the
// InferenceResponse::Send() to send the response.
Status SetResponseDelegator(
std::function<void(
std::unique_ptr<InferenceResponse>&&, const uint32_t)>&& delegator)
{
response_delegator_ = std::move(delegator);
return response_factory_->SetResponseDelegator(response_delegator_);
}
Status SetSequenceStates(
const std::shared_ptr<SequenceStates>& sequence_states)
{
sequence_states_ = sequence_states;
return Status::Success;
}
Status LoadInputStates();
const std::shared_ptr<SequenceStates>& GetSequenceStates() const
{
return sequence_states_;
}
// Prepare this request for inference.
Status PrepareForInference();
// Run this inference request using the model associated with the
// request. If Status::Success is returned then the call has taken
// ownership of the request object and so 'request' will be
// nullptr. If non-success is returned then the caller still retains
// ownership of 'request'.
static Status Run(std::unique_ptr<InferenceRequest>& request);
// Send an error response for this request. If 'status' is Success
// then no response is sent and the request is not released (even if
// 'release_request' is true). Because this is sending an error it
// is assumed that this is the last response for the request and so
// the FINAL flag is set in the response callback. If
// 'release_request' is true then the release callback is called for
// this request and ownership is given to the callback. Thus, if
// 'release_request' is true 'request' is returned as nullptr.
static void RespondIfError(
std::unique_ptr<InferenceRequest>& request, const Status& status,
const bool release_request = false);
// Send an error response to a set of 'requests'. If 'status' is
// Success then no responses are sent and the requests are not
// released (even if 'release_request' is true). Because this is
// sending an error it is assumed that this is the last response for
// the requests and so the FINAL flag is set in the response
// callbacks. If 'release_request' is true then the release callback
// is called for each request, and the request ownership is given to
// the callback. Thus, if 'release_request' is true 'requests' is
// returned with all nullptrs.
static void RespondIfError(
std::vector<std::unique_ptr<InferenceRequest>>& requests,
const Status& status, const bool release_requests = false);
// Release the request. Call the release callback and transfer
// ownership of the request to the callback. On return 'request' is
// nullptr.
static void Release(
std::unique_ptr<InferenceRequest>&& request,
const uint32_t release_flags);
// Create a copy of 'from' suitable for use as a "null" request as
// required for the direct sequence batcher. The returned copy will
// contain only the minimum content required for a null request.
// The statistics of the copy will not be collected.
static InferenceRequest* CopyAsNull(const InferenceRequest& from);
uint64_t QueueStartNs() const { return queue_start_ns_; }
uint64_t CaptureQueueStartNs()
{
queue_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return queue_start_ns_;
}
uint64_t CacheLookupStartNs() const { return cache_lookup_start_ns_; }
uint64_t CaptureCacheLookupStartNs()
{
cache_lookup_start_ns_ =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return cache_lookup_start_ns_;
}
uint64_t CacheLookupEndNs() const { return cache_lookup_end_ns_; }
uint64_t CaptureCacheLookupEndNs()
{
cache_lookup_end_ns_ =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return cache_lookup_end_ns_;
}
uint64_t CacheInsertionStartNs() const { return cache_insertion_start_ns_; }
uint64_t CaptureCacheInsertionStartNs()
{
cache_insertion_start_ns_ =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return cache_insertion_start_ns_;
}
uint64_t CacheInsertionEndNs() const { return cache_insertion_end_ns_; }
uint64_t CaptureCacheInsertionEndNs()
{
cache_insertion_end_ns_ =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return cache_insertion_end_ns_;
}
uint64_t BatcherStartNs() const { return batcher_start_ns_; }
uint64_t CaptureBatcherStartNs()
{
batcher_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return batcher_start_ns_;
}
#ifdef TRITON_ENABLE_STATS
uint64_t RequestStartNs() const { return request_start_ns_; }
uint64_t CaptureRequestStartNs()
{
request_start_ns_ = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
return request_start_ns_;
}
// Report the statistics to stats collectors associated with the request.
// Duration and timestamps provide two granularities for stats collectors.
void ReportStatistics(
MetricModelReporter* metric_reporter, bool success,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns);
// Report the statistics to stats collectors associated with the request.
// Duration and timestamps provide two granularities for stats collectors.
void ReportStatisticsWithDuration(
MetricModelReporter* metric_reporter, bool success,
const uint64_t compute_start_ns, const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns);
// Report the statistics to stats collectors associated with the request on
// response cache hits.
void ReportStatisticsCacheHit(MetricModelReporter* metric_reporter);
// Report the statistics to stats collectors associated with the request on
// response cache misses and update request duration to include cache
// insertion time.
void ReportStatisticsCacheMiss(MetricModelReporter* metric_reporter);
// Statistics for each request are aggregated into the corresponding
// model's statistics. Optionally this function may be used to
// add an additional aggregator where statistics are also aggregated.
void SetSecondaryStatsAggregator(
InferenceStatsAggregator* secondary_stats_aggregator)
{
secondary_stats_aggregator_ = secondary_stats_aggregator;
}
#endif // TRITON_ENABLE_STATS
private:
DISALLOW_COPY_AND_ASSIGN(InferenceRequest);
friend std::ostream& operator<<(
std::ostream& out, const InferenceRequest& request);
Status Normalize();
// Whether anything in the request has potentially changed in a way
// that causes normalization to be required when preparing the request
// for inference.
bool needs_normalization_;
// The model associated with this request. For most requests
// model_shared_ will be non-null and will act to keep the model
// alive as long as this request is live. In this case model_raw_
// will be the raw pointer from the shared pointer. For cases where
// the model itself created the request (like running requests for
// warmup), model_shared_ will be nullptr, but model_raw_ will
// still be defined. Thus model_raw_ is always defined and should
// always be used to access the model.
std::shared_ptr<Model> model_shared_;
Model* model_raw_;
// The model version as requested and based on version policy the
// specific version that is actually used for inference.
int64_t requested_model_version_;
int64_t actual_model_version_;
std::string id_;
uint32_t flags_;
SequenceId correlation_id_;
uint32_t batch_size_;
uint32_t priority_;
uint64_t timeout_us_;
uint64_t cache_key_ = 0;
// Helper to determine if request was successfully hashed
// and cache_key_ field is valid
bool cache_key_is_set_ = false;
std::unordered_map<std::string, Input> original_inputs_;
std::unordered_map<std::string, std::shared_ptr<Input>> override_inputs_;
std::unordered_map<std::string, Input*> inputs_;
std::set<std::string> original_requested_outputs_;
std::string raw_input_name_;
uint32_t raw_input_size_;
// requested_outputs_ is to be used post-normalization. It will be
// empty unless it differs from original_requested_outputs_, so it should
// typically be accessed through ImmutableRequestedOutputs().
std::set<std::string> requested_outputs_;
// The release function and user pointer for this request.
TRITONSERVER_InferenceRequestReleaseFn_t release_fn_;
void* release_userp_;
// Additional release callbacks invoked before 'release_fn_'.
std::vector<std::function<void()>> release_callbacks_;
// Delegator to be invoked on sending responses.
std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
response_delegator_;
// The response factory associated with this request.
std::shared_ptr<InferenceResponseFactory> response_factory_;
// Request timestamps. Queue start is needed for schedulers even
// when statistics are not being collected.
uint64_t queue_start_ns_;
// Cache lookup start/end timestamps. Cache manages its own stats even
// when statistics are not being collected.
uint64_t cache_lookup_start_ns_;
uint64_t cache_lookup_end_ns_;
// Cache insertion start/end timestamps. Cache manages its own stats even
// when statistics are not being collected.
uint64_t cache_insertion_start_ns_;
uint64_t cache_insertion_end_ns_;
// Dedicated timestamp for internal batcher use, which can diverge from
// the queue start timestamp to provide an accurate queue time without
// affecting batcher functionality.
uint64_t batcher_start_ns_;
// Whether the stats of the request should be collected.
bool collect_stats_;
#ifdef TRITON_ENABLE_STATS
uint64_t request_start_ns_;
InferenceStatsAggregator* secondary_stats_aggregator_ = nullptr;
#endif // TRITON_ENABLE_STATS
#ifdef TRITON_ENABLE_TRACING
// Inference trace associated with this request.
std::shared_ptr<InferenceTraceProxy> trace_;
#endif // TRITON_ENABLE_TRACING
// Sequence I/O states used for implicit state.
std::shared_ptr<SequenceStates> sequence_states_;
};
std::ostream& operator<<(std::ostream& out, const InferenceRequest& request);
std::ostream& operator<<(
std::ostream& out, const InferenceRequest::Input& input);
std::ostream& operator<<(
std::ostream& out, const InferenceRequest::SequenceId& sequence_id);
bool operator==(
const InferenceRequest::SequenceId lhs,
const InferenceRequest::SequenceId rhs);
}} // namespace triton::core
namespace std {
using namespace triton::core;
template <>
class hash<InferenceRequest::SequenceId> {
public:
size_t operator()(const InferenceRequest::SequenceId& sequence_id) const
{
if (sequence_id.Type() == InferenceRequest::SequenceId::DataType::STRING) {
return std::hash<std::string>{}(sequence_id.StringValue());
}
return std::hash<uint64_t>{}(sequence_id.UnsignedIntValue());
}
};
} // namespace std
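// Illustrative usage sketch (not part of the original source): the hash
// specialization above lets SequenceId key standard unordered containers,
// assuming a SequenceId can be constructed from a uint64_t correlation id:
//
//   std::unordered_map<triton::core::InferenceRequest::SequenceId, size_t>
//       sequence_slots;
//   sequence_slots[triton::core::InferenceRequest::SequenceId(42)] = 0;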
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_response.h"
#include "model.h"
#include "model_config_utils.h"
#include "server.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
//
// InferenceResponseFactory
//
Status
InferenceResponseFactory::CreateResponse(
std::unique_ptr<InferenceResponse>* response) const
{
response->reset(new InferenceResponse(
model_, id_, allocator_, alloc_userp_, response_fn_, response_userp_,
response_delegator_));
#ifdef TRITON_ENABLE_TRACING
(*response)->SetTrace(trace_);
#endif // TRITON_ENABLE_TRACING
return Status::Success;
}
Status
InferenceResponseFactory::SendFlags(const uint32_t flags) const
{
if (response_delegator_ != nullptr) {
std::unique_ptr<InferenceResponse> response(
new InferenceResponse(response_fn_, response_userp_));
response_delegator_(std::move(response), flags);
} else {
void* userp = response_userp_;
response_fn_(nullptr /* response */, flags, userp);
}
return Status::Success;
}
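// Illustrative sketch (not part of the original source): SendFlags is the
// path for flag-only notifications where no response object exists, e.g.
// signaling the end of a response stream:
//
//   RETURN_IF_ERROR(factory->SendFlags(TRITONSERVER_RESPONSE_COMPLETE_FINAL));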
//
// InferenceResponse
//
InferenceResponse::InferenceResponse(
const std::shared_ptr<Model>& model, const std::string& id,
const ResponseAllocator* allocator, void* alloc_userp,
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp,
const std::function<
void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
: model_(model), id_(id), allocator_(allocator), alloc_userp_(alloc_userp),
response_fn_(response_fn), response_userp_(response_userp),
response_delegator_(delegator), null_response_(false)
{
// If the allocator has a start_fn then invoke it.
TRITONSERVER_ResponseAllocatorStartFn_t start_fn = allocator_->StartFn();
if (start_fn != nullptr) {
LOG_TRITONSERVER_ERROR(
start_fn(
reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
const_cast<ResponseAllocator*>(allocator_)),
alloc_userp_),
"response allocation start failed");
}
}
InferenceResponse::InferenceResponse(
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp)
: response_fn_(response_fn), response_userp_(response_userp),
null_response_(true)
{
}
const std::string&
InferenceResponse::ModelName() const
{
static const std::string unknown("<unknown>");
return (model_ == nullptr) ? unknown : model_->Name();
}
int64_t
InferenceResponse::ActualModelVersion() const
{
return (model_ == nullptr) ? -1 : model_->Version();
}
Status
InferenceResponse::AddParameter(const char* name, const char* value)
{
parameters_.emplace_back(name, value);
return Status::Success;
}
Status
InferenceResponse::AddParameter(const char* name, const int64_t value)
{
parameters_.emplace_back(name, value);
return Status::Success;
}
Status
InferenceResponse::AddParameter(const char* name, const bool value)
{
parameters_.emplace_back(name, value);
return Status::Success;
}
Status
InferenceResponse::AddOutput(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape, InferenceResponse::Output** output)
{
outputs_.emplace_back(name, datatype, shape, allocator_, alloc_userp_);
LOG_VERBOSE(1) << "add response output: " << outputs_.back();
if (model_ != nullptr) {
const inference::ModelOutput* output_config;
RETURN_IF_ERROR(model_->GetOutput(name, &output_config));
if (output_config->has_reshape()) {
const bool has_batch_dim = (model_->Config().max_batch_size() > 0);
outputs_.back().Reshape(has_batch_dim, output_config);
}
}
if (output != nullptr) {
*output = std::addressof(outputs_.back());
}
return Status::Success;
}
Status
InferenceResponse::AddOutput(
const std::string& name, const inference::DataType datatype,
std::vector<int64_t>&& shape, InferenceResponse::Output** output)
{
outputs_.emplace_back(
name, datatype, std::move(shape), allocator_, alloc_userp_);
LOG_VERBOSE(1) << "add response output: " << outputs_.back();
if (model_ != nullptr) {
const inference::ModelOutput* output_config;
RETURN_IF_ERROR(model_->GetOutput(name, &output_config));
if (output_config->has_reshape()) {
const bool has_batch_dim = (model_->Config().max_batch_size() > 0);
outputs_.back().Reshape(has_batch_dim, output_config);
}
}
if (output != nullptr) {
*output = std::addressof(outputs_.back());
}
return Status::Success;
}
Status
InferenceResponse::ClassificationLabel(
const InferenceResponse::Output& output, const uint32_t class_index,
const char** label) const
{
const auto& label_provider = model_->GetLabelProvider();
const std::string& l = label_provider->GetLabel(output.Name(), class_index);
if (l.empty()) {
*label = nullptr;
} else {
*label = l.c_str();
}
return Status::Success;
}
Status
InferenceResponse::Send(
std::unique_ptr<InferenceResponse>&& response, const uint32_t flags)
{
#ifdef TRITON_ENABLE_TRACING
response->TraceOutputTensors(
TRITONSERVER_TRACE_TENSOR_BACKEND_OUTPUT, "InferenceResponse Send");
#endif // TRITON_ENABLE_TRACING
if (response->response_delegator_ != nullptr) {
auto ldelegator = std::move(response->response_delegator_);
ldelegator(std::move(response), flags);
return Status::Success;
}
void* userp = response->response_userp_;
if (response->null_response_) {
response->response_fn_(nullptr /* response */, flags, userp);
} else {
auto& response_fn = response->response_fn_;
response_fn(
reinterpret_cast<TRITONSERVER_InferenceResponse*>(response.release()),
flags, userp);
}
return Status::Success;
}
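// Illustrative usage sketch (not part of the original source; the factory
// variable, tensor name, and shape are assumptions): a typical flow creates
// a response from its factory, adds outputs, and hands ownership to the
// completion callback via Send:
//
//   std::unique_ptr<InferenceResponse> response;
//   RETURN_IF_ERROR(factory->CreateResponse(&response));
//   InferenceResponse::Output* output = nullptr;
//   RETURN_IF_ERROR(response->AddOutput(
//       "OUTPUT0", inference::TYPE_FP32, std::vector<int64_t>{1, 3},
//       &output));
//   RETURN_IF_ERROR(InferenceResponse::Send(
//       std::move(response), TRITONSERVER_RESPONSE_COMPLETE_FINAL));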
Status
InferenceResponse::SendWithStatus(
std::unique_ptr<InferenceResponse>&& response, const uint32_t flags,
const Status& status)
{
response->status_ = status;
return InferenceResponse::Send(std::move(response), flags);
}
#ifdef TRITON_ENABLE_TRACING
Status
InferenceResponse::TraceOutputTensors(
TRITONSERVER_InferenceTraceActivity activity, const std::string& msg)
{
const auto& outputs = this->Outputs();
uint32_t output_count = outputs.size();
for (uint32_t idx = 0; idx < output_count; ++idx) {
const Output& output = outputs[idx];
// output data
const char* cname = output.Name().c_str();
TRITONSERVER_DataType datatype = DataTypeToTriton(output.DType());
const std::vector<int64_t>& oshape = output.Shape();
const int64_t* shape = &oshape[0];
uint64_t dim_count = oshape.size();
const void* base;
size_t byte_size;
TRITONSERVER_MemoryType memory_type;
int64_t memory_type_id;
void* userp;
Status status = output.DataBuffer(
&base, &byte_size, &memory_type, &memory_type_id, &userp);
if (!status.IsOk()) {
LOG_STATUS_ERROR(
status,
std::string(TRITONSERVER_InferenceTraceActivityString(activity)) +
": " + msg + ": fail to get data buffer: " + status.Message());
return status;
}
INFER_TRACE_TENSOR_ACTIVITY(
this->trace_, activity, cname, datatype, base, byte_size, shape,
dim_count, memory_type, memory_type_id);
}
return Status::Success;
}
#endif // TRITON_ENABLE_TRACING
//
// InferenceResponse::Output
//
InferenceResponse::Output::~Output()
{
Status status = ReleaseDataBuffer();
if (!status.IsOk()) {
LOG_ERROR << "failed to release buffer for output '" << name_
<< "': " << status.AsString();
}
}
void
InferenceResponse::Output::Reshape(
const bool has_batch_dim, const inference::ModelOutput* output_config)
{
std::deque<int64_t> variable_size_values;
const int64_t batch_dim =
(has_batch_dim && (shape_.size() > 0)) ? shape_[0] : -1;
const size_t batch_dim_offset = (has_batch_dim) ? 1 : 0;
const auto& from_shape = output_config->reshape().shape();
const auto& to_shape = output_config->dims();
for (int64_t idx = 0; idx < from_shape.size(); idx++) {
if (from_shape[idx] == -1) {
variable_size_values.push_back(shape_[idx + batch_dim_offset]);
}
}
shape_.clear();
if (batch_dim >= 0) {
shape_.push_back(batch_dim);
}
for (const auto& dim : to_shape) {
if (dim == -1) {
shape_.push_back(variable_size_values.front());
variable_size_values.pop_front();
} else {
shape_.push_back(dim);
}
}
}
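// Illustrative example (not part of the original source): assume a batched
// model whose config declares dims [ -1 ] with reshape { shape: [ -1, 1 ] }.
// If the backend produced an output with shape [4, 7, 1] (batch size 4),
// the loop above records the variable-size value 7 from the reshape
// position, and the rebuilt shape reported to the client is [4, 7].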
Status
InferenceResponse::Output::DataBuffer(
const void** buffer, size_t* buffer_byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
void** userp) const
{
*buffer = allocated_buffer_;
*buffer_byte_size = buffer_attributes_.ByteSize();
*memory_type = buffer_attributes_.MemoryType();
*memory_type_id = buffer_attributes_.MemoryTypeId();
*userp = allocated_userp_;
return Status::Success;
}
Status
InferenceResponse::Output::AllocateDataBuffer(
void** buffer, size_t buffer_byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id)
{
if (allocated_buffer_ != nullptr) {
return Status(
Status::Code::ALREADY_EXISTS,
"allocated buffer for output '" + name_ + "' already exists");
}
TRITONSERVER_MemoryType actual_memory_type = *memory_type;
int64_t actual_memory_type_id = *memory_type_id;
void* alloc_buffer_userp = nullptr;
RETURN_IF_TRITONSERVER_ERROR(allocator_->AllocFn()(
reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
const_cast<ResponseAllocator*>(allocator_)),
name_.c_str(), buffer_byte_size, *memory_type, *memory_type_id,
alloc_userp_, buffer, &alloc_buffer_userp, &actual_memory_type,
&actual_memory_type_id));
// Only call the buffer attributes API if it is set.
if (allocator_->BufferAttributesFn() != nullptr) {
RETURN_IF_TRITONSERVER_ERROR(allocator_->BufferAttributesFn()(
reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
const_cast<ResponseAllocator*>(allocator_)),
name_.c_str(),
reinterpret_cast<TRITONSERVER_BufferAttributes*>(&buffer_attributes_),
alloc_userp_, alloc_buffer_userp));
}
allocated_buffer_ = *buffer;
buffer_attributes_.SetByteSize(buffer_byte_size);
buffer_attributes_.SetMemoryType(actual_memory_type);
buffer_attributes_.SetMemoryTypeId(actual_memory_type_id);
allocated_userp_ = alloc_buffer_userp;
*memory_type = actual_memory_type;
*memory_type_id = actual_memory_type_id;
return Status::Success;
}
Status
InferenceResponse::Output::ReleaseDataBuffer()
{
TRITONSERVER_Error* err = nullptr;
if (allocated_buffer_ != nullptr) {
err = allocator_->ReleaseFn()(
reinterpret_cast<TRITONSERVER_ResponseAllocator*>(
const_cast<ResponseAllocator*>(allocator_)),
allocated_buffer_, allocated_userp_, buffer_attributes_.ByteSize(),
buffer_attributes_.MemoryType(), buffer_attributes_.MemoryTypeId());
}
allocated_buffer_ = nullptr;
buffer_attributes_.SetByteSize(0);
buffer_attributes_.SetMemoryType(TRITONSERVER_MEMORY_CPU);
buffer_attributes_.SetMemoryTypeId(0);
allocated_userp_ = nullptr;
RETURN_IF_TRITONSERVER_ERROR(err);
return Status::Success;
}
std::ostream&
operator<<(std::ostream& out, const InferenceResponse& response)
{
out << "[0x" << std::addressof(response) << "] "
<< "response id: " << response.Id() << ", model: " << response.ModelName()
<< ", actual version: " << response.ActualModelVersion() << std::endl;
out << "status:" << response.ResponseStatus().AsString() << std::endl;
out << "outputs:" << std::endl;
for (const auto& output : response.Outputs()) {
out << "[0x" << std::addressof(output) << "] " << output << std::endl;
}
return out;
}
std::ostream&
operator<<(std::ostream& out, const InferenceResponse::Output& output)
{
out << "output: " << output.Name()
<< ", type: " << triton::common::DataTypeToProtocolString(output.DType())
<< ", shape: " << triton::common::DimsListToString(output.Shape());
return out;
}
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <deque>
#include <functional>
#include <string>
#include <vector>
#include "buffer_attributes.h"
#include "constants.h"
#include "infer_parameter.h"
#include "infer_trace.h"
#include "response_allocator.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
class Model;
class InferenceResponse;
//
// An inference response factory.
//
class InferenceResponseFactory {
public:
InferenceResponseFactory() = default;
InferenceResponseFactory(
const std::shared_ptr<Model>& model, const std::string& id,
const ResponseAllocator* allocator, void* alloc_userp,
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp,
const std::function<void(
std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
: model_(model), id_(id), allocator_(allocator),
alloc_userp_(alloc_userp), response_fn_(response_fn),
response_userp_(response_userp), response_delegator_(delegator)
{
}
const ResponseAllocator* Allocator() { return allocator_; }
void* AllocatorUserp() { return alloc_userp_; }
Status SetResponseDelegator(
const std::function<void(
std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator)
{
response_delegator_ = delegator;
return Status::Success;
}
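// Illustrative sketch (not part of the original source): a delegator can
// be installed to intercept responses before they reach the completion
// callback, e.g.
//
//   factory.SetResponseDelegator(
//       [](std::unique_ptr<InferenceResponse>&& response,
//          const uint32_t flags) {
//         // Inspect or re-route 'response' here, then forward it.
//         InferenceResponse::Send(std::move(response), flags);
//       });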
// Create a new response.
Status CreateResponse(std::unique_ptr<InferenceResponse>* response) const;
// Send a "null" response with 'flags'.
Status SendFlags(const uint32_t flags) const;
#ifdef TRITON_ENABLE_TRACING
const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
{
trace_ = trace;
}
void ReleaseTrace() { trace_ = nullptr; }
#endif // TRITON_ENABLE_TRACING
private:
// The model associated with this factory. For normal
// requests/responses this will always be defined and acts to keep
// the model loaded as long as this factory is live. It may be
// nullptr for cases where the model itself created the request
// (like running requests for warmup) and so must protect any uses
// to handle the nullptr case.
std::shared_ptr<Model> model_;
// The ID of the corresponding request that should be included in every
// response. This is a property that can be optionally provided by the user.
std::string id_;
// The response allocator and user pointer. The 'allocator_' is a
// raw pointer because it is owned by the client, and the client is
// responsible for ensuring that the lifetime of the allocator
// extends longer than any request or response that depends on the
// allocator.
const ResponseAllocator* allocator_;
void* alloc_userp_;
// The response callback function and user pointer.
TRITONSERVER_InferenceResponseCompleteFn_t response_fn_;
void* response_userp_;
// Delegator to be invoked on sending responses.
std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
response_delegator_;
#ifdef TRITON_ENABLE_TRACING
// Inference trace associated with this response.
std::shared_ptr<InferenceTraceProxy> trace_;
#endif // TRITON_ENABLE_TRACING
};
//
// An inference response.
//
class InferenceResponse {
public:
// Output tensor
class Output {
public:
Output(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape, const ResponseAllocator* allocator,
void* alloc_userp)
: name_(name), datatype_(datatype), shape_(shape),
allocator_(allocator), alloc_userp_(alloc_userp),
allocated_buffer_(nullptr)
{
}
Output(
const std::string& name, const inference::DataType datatype,
std::vector<int64_t>&& shape, const ResponseAllocator* allocator,
void* alloc_userp)
: name_(name), datatype_(datatype), shape_(std::move(shape)),
allocator_(allocator), alloc_userp_(alloc_userp),
allocated_buffer_(nullptr)
{
}
~Output();
// The name of the output tensor.
const std::string& Name() const { return name_; }
// Data type of the output tensor.
inference::DataType DType() const { return datatype_; }
// The shape of the output tensor.
const std::vector<int64_t>& Shape() const { return shape_; }
BufferAttributes* GetBufferAttributes() { return &buffer_attributes_; }
// Reshape the output tensor. This function must only be called
// for outputs that have reshape specified in the model
// configuration.
void Reshape(
const bool has_batch_dim, const inference::ModelOutput* output_config);
// Get information about the buffer allocated for this output
// tensor's data. If no buffer is allocated 'buffer' will return
// nullptr and the other returned values will be undefined.
Status DataBuffer(
const void** buffer, size_t* buffer_byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id,
void** userp) const;
// Allocate the buffer that should be used for this output
// tensor's data. 'buffer' must return a buffer of size
// 'buffer_byte_size'. 'memory_type' acts as both input and
// output. On input gives the buffer memory type preferred by the
// caller and on return holds the actual memory type of
// 'buffer'. 'memory_type_id' acts as both input and output. On
// input gives the buffer memory type id preferred by the caller
// and returns the actual memory type id of 'buffer'. Only a
// single buffer may be allocated for the output at any time, so
// multiple calls to AllocateDataBuffer without intervening
// ReleaseDataBuffer call will result in an error.
Status AllocateDataBuffer(
void** buffer, const size_t buffer_byte_size,
TRITONSERVER_MemoryType* memory_type, int64_t* memory_type_id);
// Release the buffer that was previously allocated by
// AllocateDataBuffer(). Do nothing if AllocateDataBuffer() has
// not been called.
Status ReleaseDataBuffer();
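// Illustrative sketch (not part of the original source; 'output' and
// 'byte_size' are assumed variables) of the in/out memory-type semantics
// described above:
//
//   void* buffer = nullptr;
//   TRITONSERVER_MemoryType mem_type = TRITONSERVER_MEMORY_GPU;  // preference
//   int64_t mem_type_id = 0;
//   RETURN_IF_ERROR(output->AllocateDataBuffer(
//       &buffer, byte_size, &mem_type, &mem_type_id));
//   // 'mem_type' / 'mem_type_id' now describe what was actually allocated,
//   // which may differ from the preference (e.g. TRITONSERVER_MEMORY_CPU).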
private:
DISALLOW_COPY_AND_ASSIGN(Output);
friend std::ostream& operator<<(
std::ostream& out, const InferenceResponse::Output& output);
std::string name_;
inference::DataType datatype_;
std::vector<int64_t> shape_;
// The response allocator and user pointer.
const ResponseAllocator* allocator_;
void* alloc_userp_;
// Information about the buffer allocated by
// AllocateDataBuffer(). This information is needed by
// DataBuffer() and ReleaseDataBuffer().
void* allocated_buffer_;
BufferAttributes buffer_attributes_;
void* allocated_userp_;
};
// InferenceResponse
InferenceResponse(
const std::shared_ptr<Model>& model, const std::string& id,
const ResponseAllocator* allocator, void* alloc_userp,
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp,
const std::function<void(
std::unique_ptr<InferenceResponse>&&, const uint32_t)>& delegator);
// "null" InferenceResponse is a special instance of InferenceResponse which
// contains minimal information for calling InferenceResponse::Send,
// InferenceResponse::NullResponse. nullptr will be passed as response in
// 'response_fn'.
InferenceResponse(
TRITONSERVER_InferenceResponseCompleteFn_t response_fn,
void* response_userp);
const std::string& Id() const { return id_; }
const std::string& ModelName() const;
int64_t ActualModelVersion() const;
const Status& ResponseStatus() const { return status_; }
// The response parameters.
const std::deque<InferenceParameter>& Parameters() const
{
return parameters_;
}
// Add a parameter to the response.
Status AddParameter(const char* name, const char* value);
Status AddParameter(const char* name, const int64_t value);
Status AddParameter(const char* name, const bool value);
// The response outputs.
const std::deque<Output>& Outputs() const { return outputs_; }
// Add an output to the response. If 'output' is non-null
// return a pointer to the newly added output.
Status AddOutput(
const std::string& name, const inference::DataType datatype,
const std::vector<int64_t>& shape, Output** output = nullptr);
Status AddOutput(
const std::string& name, const inference::DataType datatype,
std::vector<int64_t>&& shape, Output** output = nullptr);
// Get the classification label associated with an output. Return
// 'label' == nullptr if no label.
Status ClassificationLabel(
const Output& output, const uint32_t class_index,
const char** label) const;
// Send the response with success status. Calling this function
// releases ownership of the response object and gives it to the
// callback function.
static Status Send(
std::unique_ptr<InferenceResponse>&& response, const uint32_t flags);
// Send the response with explicit status. Calling this function
// releases ownership of the response object and gives it to the
// callback function.
static Status SendWithStatus(
std::unique_ptr<InferenceResponse>&& response, const uint32_t flags,
const Status& status);
#ifdef TRITON_ENABLE_TRACING
const std::shared_ptr<InferenceTraceProxy>& Trace() const { return trace_; }
void SetTrace(const std::shared_ptr<InferenceTraceProxy>& trace)
{
trace_ = trace;
}
void ReleaseTrace() { trace_ = nullptr; }
#endif // TRITON_ENABLE_TRACING
private:
DISALLOW_COPY_AND_ASSIGN(InferenceResponse);
friend std::ostream& operator<<(
std::ostream& out, const InferenceResponse& response);
#ifdef TRITON_ENABLE_TRACING
Status TraceOutputTensors(
TRITONSERVER_InferenceTraceActivity activity, const std::string& msg);
#endif // TRITON_ENABLE_TRACING
// The model associated with this factory. For normal
// requests/responses this will always be defined and acts to keep
// the model loaded as long as this factory is live. It may be
// nullptr for cases where the model itself created the request
// (like running requests for warmup) and so must protect any uses
// to handle the nullptr case.
std::shared_ptr<Model> model_;
// The ID of the corresponding request that should be included in
// every response.
std::string id_;
// Error status for the response.
Status status_;
// The parameters of the response. Use a deque so that there is no
// reallocation.
std::deque<InferenceParameter> parameters_;
// The result tensors. Use a deque so that there is no reallocation.
std::deque<Output> outputs_;
// The response allocator and user pointer.
const ResponseAllocator* allocator_;
void* alloc_userp_;
// The response callback function and user pointer.
TRITONSERVER_InferenceResponseCompleteFn_t response_fn_;
void* response_userp_;
// Delegator to be invoked on sending responses.
std::function<void(std::unique_ptr<InferenceResponse>&&, const uint32_t)>
response_delegator_;
bool null_response_;
#ifdef TRITON_ENABLE_TRACING
// Inference trace associated with this response.
std::shared_ptr<InferenceTraceProxy> trace_;
#endif // TRITON_ENABLE_TRACING
};
std::ostream& operator<<(std::ostream& out, const InferenceResponse& response);
std::ostream& operator<<(
std::ostream& out, const InferenceResponse::Output& output);
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_stats.h"
#include <time.h>
#include "metric_model_reporter.h"
#include "metrics.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
#ifdef TRITON_ENABLE_STATS
void
InferenceStatsAggregator::UpdateFailure(
MetricModelReporter* metric_reporter, const uint64_t request_start_ns,
const uint64_t request_end_ns)
{
std::lock_guard<std::mutex> lock(mu_);
infer_stats_.failure_count_++;
infer_stats_.failure_duration_ns_ += (request_end_ns - request_start_ns);
#ifdef TRITON_ENABLE_METRICS
if (metric_reporter != nullptr) {
metric_reporter->MetricInferenceFailure().Increment(1);
}
#endif // TRITON_ENABLE_METRICS
}
void
InferenceStatsAggregator::UpdateSuccess(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns,
const uint64_t request_end_ns)
{
const uint64_t compute_input_duration_ns =
compute_input_end_ns - compute_start_ns;
const uint64_t compute_infer_duration_ns =
compute_output_start_ns - compute_input_end_ns;
const uint64_t compute_output_duration_ns =
compute_end_ns - compute_output_start_ns;
UpdateSuccessWithDuration(
metric_reporter, batch_size, request_start_ns, queue_start_ns,
compute_start_ns, request_end_ns, compute_input_duration_ns,
compute_infer_duration_ns, compute_output_duration_ns);
}
void
InferenceStatsAggregator::UpdateSuccessWithDuration(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t compute_start_ns, const uint64_t request_end_ns,
const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns)
{
const uint64_t request_duration_ns = request_end_ns - request_start_ns;
const uint64_t queue_duration_ns = compute_start_ns - queue_start_ns;
std::lock_guard<std::mutex> lock(mu_);
inference_count_ += batch_size;
infer_stats_.success_count_++;
infer_stats_.request_duration_ns_ += request_duration_ns;
infer_stats_.queue_duration_ns_ += queue_duration_ns;
infer_stats_.compute_input_duration_ns_ += compute_input_duration_ns;
infer_stats_.compute_infer_duration_ns_ += compute_infer_duration_ns;
infer_stats_.compute_output_duration_ns_ += compute_output_duration_ns;
#ifdef TRITON_ENABLE_METRICS
if (metric_reporter != nullptr) {
metric_reporter->MetricInferenceSuccess().Increment(1);
metric_reporter->MetricInferenceCount().Increment(batch_size);
metric_reporter->MetricInferenceRequestDuration().Increment(
request_duration_ns / 1000);
metric_reporter->MetricInferenceQueueDuration().Increment(
queue_duration_ns / 1000);
metric_reporter->MetricInferenceComputeInputDuration().Increment(
compute_input_duration_ns / 1000);
metric_reporter->MetricInferenceComputeInferDuration().Increment(
compute_infer_duration_ns / 1000);
metric_reporter->MetricInferenceComputeOutputDuration().Increment(
compute_output_duration_ns / 1000);
}
#endif // TRITON_ENABLE_METRICS
}
// Currently cache hits will not go to the inference backend where metrics
// are typically updated, so this method allows us to update relevant metrics
// from a metric reporter rather than going through the backend.
void
InferenceStatsAggregator::UpdateSuccessCacheHit(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns,
const uint64_t cache_hit_lookup_duration_ns)
{
const uint64_t request_duration_ns = request_end_ns - request_start_ns;
const uint64_t queue_duration_ns = cache_lookup_start_ns - queue_start_ns;
std::lock_guard<std::mutex> lock(mu_);
infer_stats_.success_count_++;
infer_stats_.request_duration_ns_ += request_duration_ns;
infer_stats_.queue_duration_ns_ += queue_duration_ns;
infer_stats_.cache_hit_count_++;
infer_stats_.cache_hit_lookup_duration_ns_ += cache_hit_lookup_duration_ns;
#ifdef TRITON_ENABLE_METRICS
if (metric_reporter != nullptr) {
metric_reporter->MetricInferenceSuccess().Increment(1);
metric_reporter->MetricInferenceRequestDuration().Increment(
request_duration_ns / 1000);
metric_reporter->MetricInferenceQueueDuration().Increment(
queue_duration_ns / 1000);
metric_reporter->MetricCacheHitCount().Increment(1);
metric_reporter->MetricCacheHitLookupDuration().Increment(
cache_hit_lookup_duration_ns / 1000);
}
#endif // TRITON_ENABLE_METRICS
}
// Cache misses will go to the inference backend where metrics are typically
// updated, but cache insertion happens after the inference backend finishes.
// So we use this method to update cache miss stats and adjust the request
// duration to include cache insertion time.
void
InferenceStatsAggregator::UpdateSuccessCacheMiss(
MetricModelReporter* metric_reporter,
const uint64_t cache_miss_lookup_duration_ns,
const uint64_t cache_miss_insertion_duration_ns)
{
std::lock_guard<std::mutex> lock(mu_);
const uint64_t cache_miss_duration_ns =
cache_miss_lookup_duration_ns + cache_miss_insertion_duration_ns;
infer_stats_.request_duration_ns_ += cache_miss_duration_ns;
infer_stats_.cache_miss_count_++;
infer_stats_.cache_miss_lookup_duration_ns_ += cache_miss_lookup_duration_ns;
infer_stats_.cache_miss_insertion_duration_ns_ +=
cache_miss_insertion_duration_ns;
#ifdef TRITON_ENABLE_METRICS
if (metric_reporter != nullptr) {
// Add cache insertion time to request duration since insertion
// happens after inference backend sets the request duration, and
// cache lookup time was already included before the inference backend
// was called
metric_reporter->MetricInferenceRequestDuration().Increment(
cache_miss_duration_ns / 1000);
metric_reporter->MetricCacheMissCount().Increment(1);
metric_reporter->MetricCacheMissLookupDuration().Increment(
cache_miss_lookup_duration_ns / 1000);
metric_reporter->MetricCacheMissInsertionDuration().Increment(
cache_miss_insertion_duration_ns / 1000);
}
#endif // TRITON_ENABLE_METRICS
}
void
InferenceStatsAggregator::UpdateInferBatchStats(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns)
{
auto compute_input_duration_ns = (compute_input_end_ns - compute_start_ns);
auto compute_infer_duration_ns =
(compute_output_start_ns - compute_input_end_ns);
auto compute_output_duration_ns = (compute_end_ns - compute_output_start_ns);
UpdateInferBatchStatsWithDuration(
metric_reporter, batch_size, compute_input_duration_ns,
compute_infer_duration_ns, compute_output_duration_ns);
}
void
InferenceStatsAggregator::UpdateInferBatchStatsWithDuration(
MetricModelReporter* metric_reporter, size_t batch_size,
const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns)
{
uint64_t inference_ms =
std::chrono::duration_cast<std::chrono::milliseconds>(
std::chrono::system_clock::now().time_since_epoch())
.count();
std::lock_guard<std::mutex> lock(mu_);
if (inference_ms > last_inference_ms_) {
last_inference_ms_ = inference_ms;
}
execution_count_++;
auto it = batch_stats_.find(batch_size);
if (it == batch_stats_.end()) {
it = batch_stats_.emplace(batch_size, InferBatchStats()).first;
}
it->second.count_++;
it->second.compute_input_duration_ns_ += compute_input_duration_ns;
it->second.compute_infer_duration_ns_ += compute_infer_duration_ns;
it->second.compute_output_duration_ns_ += compute_output_duration_ns;
#ifdef TRITON_ENABLE_METRICS
if (metric_reporter != nullptr) {
metric_reporter->MetricInferenceExecutionCount().Increment(1);
}
#endif // TRITON_ENABLE_METRICS
}
#endif // TRITON_ENABLE_STATS
}} // namespace triton::core
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <time.h>
#include <map>
#include <memory>
#include <mutex>
#include <vector>
#include "constants.h"
#include "infer_response.h"
#include "status.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
class MetricModelReporter;
//
// InferenceStatsAggregator
//
// A statistics aggregator.
//
class InferenceStatsAggregator {
#ifdef TRITON_ENABLE_STATS
public:
struct InferStats {
InferStats()
: failure_count_(0), failure_duration_ns_(0), success_count_(0),
request_duration_ns_(0), queue_duration_ns_(0),
compute_input_duration_ns_(0), compute_infer_duration_ns_(0),
compute_output_duration_ns_(0), cache_hit_count_(0),
cache_hit_lookup_duration_ns_(0), cache_miss_count_(0),
cache_miss_lookup_duration_ns_(0),
cache_miss_insertion_duration_ns_(0)
{
}
uint64_t failure_count_;
uint64_t failure_duration_ns_;
uint64_t success_count_;
uint64_t request_duration_ns_;
uint64_t queue_duration_ns_;
uint64_t compute_input_duration_ns_;
uint64_t compute_infer_duration_ns_;
uint64_t compute_output_duration_ns_;
// Cache hit stats
uint64_t cache_hit_count_;
uint64_t cache_hit_lookup_duration_ns_;
// Cache miss stats
uint64_t cache_miss_count_;
uint64_t cache_miss_lookup_duration_ns_;
uint64_t cache_miss_insertion_duration_ns_;
};
struct InferBatchStats {
InferBatchStats()
: count_(0), compute_input_duration_ns_(0),
compute_infer_duration_ns_(0), compute_output_duration_ns_(0)
{
}
uint64_t count_;
uint64_t compute_input_duration_ns_;
uint64_t compute_infer_duration_ns_;
uint64_t compute_output_duration_ns_;
};
// Create an aggregator for model statistics
InferenceStatsAggregator()
: last_inference_ms_(0), inference_count_(0), execution_count_(0)
{
}
uint64_t LastInferenceMs() const { return last_inference_ms_; }
uint64_t InferenceCount() const { return inference_count_; }
uint64_t ExecutionCount() const { return execution_count_; }
const InferStats& ImmutableInferStats() const { return infer_stats_; }
const std::map<size_t, InferBatchStats>& ImmutableInferBatchStats() const
{
return batch_stats_;
}
// Add durations to Infer stats for a failed inference request.
void UpdateFailure(
MetricModelReporter* metric_reporter, const uint64_t request_start_ns,
const uint64_t request_end_ns);
// Add durations to infer stats for a successful inference request.
void UpdateSuccess(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns,
const uint64_t request_end_ns);
// Add durations to infer stats for a successful inference request.
void UpdateSuccessWithDuration(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t compute_start_ns, const uint64_t request_end_ns,
const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns);
// Add durations to infer stats for a successful cached response.
void UpdateSuccessCacheHit(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t request_start_ns, const uint64_t queue_start_ns,
const uint64_t cache_lookup_start_ns, const uint64_t request_end_ns,
const uint64_t cache_hit_lookup_duration_ns);
// Add durations to infer stats for a cache miss and update request duration
// to account for cache insertion after backend computes the response.
void UpdateSuccessCacheMiss(
MetricModelReporter* metric_reporter,
const uint64_t cache_miss_lookup_duration_ns,
const uint64_t cache_miss_insertion_duration_ns);
// Add durations to batch infer stats for a batch execution with the
// given 'batch_size'.
void UpdateInferBatchStats(
MetricModelReporter* metric_reporter, const size_t batch_size,
const uint64_t compute_start_ns, const uint64_t compute_input_end_ns,
const uint64_t compute_output_start_ns, const uint64_t compute_end_ns);
// Add durations to batch infer stats for a batch execution with the
// given 'batch_size'.
void UpdateInferBatchStatsWithDuration(
MetricModelReporter* metric_reporter, size_t batch_size,
const uint64_t compute_input_duration_ns,
const uint64_t compute_infer_duration_ns,
const uint64_t compute_output_duration_ns);
private:
std::mutex mu_;
uint64_t last_inference_ms_;
uint64_t inference_count_;
uint64_t execution_count_;
InferStats infer_stats_;
std::map<size_t, InferBatchStats> batch_stats_;
#endif // TRITON_ENABLE_STATS
};
//
// Macros to set infer stats.
//
#ifdef TRITON_ENABLE_STATS
#define INFER_STATS_SET_TIMESTAMP(TS_NS) \
{ \
TS_NS = std::chrono::duration_cast<std::chrono::nanoseconds>( \
std::chrono::steady_clock::now().time_since_epoch()) \
.count(); \
}
#define INFER_STATS_DECL_TIMESTAMP(TS_NS) \
uint64_t TS_NS; \
INFER_STATS_SET_TIMESTAMP(TS_NS);
#else
#define INFER_STATS_DECL_TIMESTAMP(TS_NS)
#define INFER_STATS_SET_TIMESTAMP(TS_NS)
#endif // TRITON_ENABLE_STATS
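// Illustrative usage sketch (not part of the original source): call sites
// typically bracket the compute phases with these macros and later report
// the captured values:
//
//   INFER_STATS_DECL_TIMESTAMP(compute_start_ns);
//   ... run the model ...
//   INFER_STATS_DECL_TIMESTAMP(compute_end_ns);
//
// When TRITON_ENABLE_STATS is not defined the macros expand to nothing, so
// any code consuming the timestamps must also be compiled out.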
}} // namespace triton::core
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "infer_trace.h"
namespace triton { namespace core {
#ifdef TRITON_ENABLE_TRACING
// Start the trace id at 1, because id 0 is reserved to indicate no
// parent.
std::atomic<uint64_t> InferenceTrace::next_id_(1);
InferenceTrace*
InferenceTrace::SpawnChildTrace()
{
InferenceTrace* trace = new InferenceTrace(
level_, id_, activity_fn_, tensor_activity_fn_, release_fn_, userp_);
return trace;
}
void
InferenceTrace::Release()
{
release_fn_(reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), userp_);
}
std::shared_ptr<InferenceTraceProxy>
InferenceTraceProxy::SpawnChildTrace()
{
std::shared_ptr<InferenceTraceProxy> strace_proxy =
std::make_shared<InferenceTraceProxy>(trace_->SpawnChildTrace());
return strace_proxy;
}
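// Illustrative sketch (not part of the original source; 'parent_trace' is
// an assumed variable): composing schedulers spawn a child trace for each
// internal request so that ParentId() preserves the relationship:
//
//   std::shared_ptr<InferenceTraceProxy> child_trace;
//   if (parent_trace != nullptr) {
//     child_trace = parent_trace->SpawnChildTrace();
//   }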
#endif // TRITON_ENABLE_TRACING
}} // namespace triton::core
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <atomic>
#include <chrono>
#include <memory>
#include "constants.h"
#include "status.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
#ifdef TRITON_ENABLE_TRACING
//
// InferenceTrace
//
// Interface to TRITONSERVER_InferenceTrace to report trace events.
//
class InferenceTrace {
public:
InferenceTrace(
const TRITONSERVER_InferenceTraceLevel level, const uint64_t parent_id,
TRITONSERVER_InferenceTraceActivityFn_t activity_fn,
TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn,
TRITONSERVER_InferenceTraceReleaseFn_t release_fn, void* userp)
: level_(level), id_(next_id_++), parent_id_(parent_id),
activity_fn_(activity_fn), tensor_activity_fn_(tensor_activity_fn),
release_fn_(release_fn), userp_(userp)
{
}
InferenceTrace* SpawnChildTrace();
int64_t Id() const { return id_; }
int64_t ParentId() const { return parent_id_; }
const std::string& ModelName() const { return model_name_; }
int64_t ModelVersion() const { return model_version_; }
void SetModelName(const std::string& n) { model_name_ = n; }
void SetModelVersion(int64_t v) { model_version_ = v; }
// Report trace activity.
void Report(
const TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns)
{
if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) {
activity_fn_(
reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), activity,
timestamp_ns, userp_);
}
}
// Report trace activity at the current time.
void ReportNow(const TRITONSERVER_InferenceTraceActivity activity)
{
if ((level_ & TRITONSERVER_TRACE_LEVEL_TIMESTAMPS) > 0) {
Report(
activity, std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count());
}
}
// Report tensor trace activity.
void ReportTensor(
const TRITONSERVER_InferenceTraceActivity activity, const char* name,
TRITONSERVER_DataType datatype, const void* base, size_t byte_size,
const int64_t* shape, uint64_t dim_count,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
{
if ((level_ & TRITONSERVER_TRACE_LEVEL_TENSORS) > 0) {
tensor_activity_fn_(
reinterpret_cast<TRITONSERVER_InferenceTrace*>(this), activity, name,
datatype, base, byte_size, shape, dim_count, memory_type,
memory_type_id, userp_);
}
}
// Release the trace. Call the trace release callback.
void Release();
private:
const TRITONSERVER_InferenceTraceLevel level_;
const uint64_t id_;
const uint64_t parent_id_;
TRITONSERVER_InferenceTraceActivityFn_t activity_fn_;
TRITONSERVER_InferenceTraceTensorActivityFn_t tensor_activity_fn_;
TRITONSERVER_InferenceTraceReleaseFn_t release_fn_;
void* userp_;
std::string model_name_;
int64_t model_version_;
// Maintain the next id statically so that trace ids are unique even
// across traces.
static std::atomic<uint64_t> next_id_;
};
//
// InferenceTraceProxy
//
// Object attached as shared_ptr to InferenceRequest and
// InferenceResponse(s) being traced as part of a single inference
// request.
//
class InferenceTraceProxy {
public:
InferenceTraceProxy(InferenceTrace* trace) : trace_(trace) {}
~InferenceTraceProxy() { trace_->Release(); }
int64_t Id() const { return trace_->Id(); }
int64_t ParentId() const { return trace_->ParentId(); }
const std::string& ModelName() const { return trace_->ModelName(); }
int64_t ModelVersion() const { return trace_->ModelVersion(); }
void SetModelName(const std::string& n) { trace_->SetModelName(n); }
void SetModelVersion(int64_t v) { trace_->SetModelVersion(v); }
void Report(
const TRITONSERVER_InferenceTraceActivity activity, uint64_t timestamp_ns)
{
trace_->Report(activity, timestamp_ns);
}
void ReportNow(const TRITONSERVER_InferenceTraceActivity activity)
{
trace_->ReportNow(activity);
}
void ReportTensor(
const TRITONSERVER_InferenceTraceActivity activity, const char* name,
TRITONSERVER_DataType datatype, const void* base, size_t byte_size,
const int64_t* shape, uint64_t dim_count,
TRITONSERVER_MemoryType memory_type, int64_t memory_type_id)
{
trace_->ReportTensor(
activity, name, datatype, base, byte_size, shape, dim_count,
memory_type, memory_type_id);
}
std::shared_ptr<InferenceTraceProxy> SpawnChildTrace();
private:
InferenceTrace* trace_;
};
#endif // TRITON_ENABLE_TRACING
//
// Macros to generate trace activity
//
#ifdef TRITON_ENABLE_TRACING
#define INFER_TRACE_ACTIVITY(T, A, TS_NS) \
{ \
const auto& trace = (T); \
const auto ts_ns = (TS_NS); \
if (trace != nullptr) { \
trace->Report(A, ts_ns); \
} \
}
#define INFER_TRACE_ACTIVITY_NOW(T, A) \
{ \
const auto& trace = (T); \
if (trace != nullptr) { \
trace->ReportNow(A); \
} \
}
#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI) \
{ \
const auto& trace = (T); \
if (trace != nullptr) { \
trace->ReportTensor(A, N, D, BA, BY, S, DI, MT, MTI); \
} \
}
#else
#define INFER_TRACE_ACTIVITY(T, A, TS_NS)
#define INFER_TRACE_ACTIVITY_NOW(T, A)
#define INFER_TRACE_TENSOR_ACTIVITY(T, A, N, D, BA, BY, S, DI, MT, MTI)
#endif // TRITON_ENABLE_TRACING
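// Illustrative usage sketch (not part of the original source; the Trace()
// accessor on the request is an assumption): callers pass the shared trace
// proxy, which may be nullptr when tracing is not active, and the macros
// compile away entirely when TRITON_ENABLE_TRACING is not defined:
//
//   INFER_TRACE_ACTIVITY_NOW(request->Trace(), TRITONSERVER_TRACE_QUEUE_START);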
}} // namespace triton::core
// Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "instance_queue.h"
#include "triton/common/logging.h"
namespace triton { namespace core {
InstanceQueue::InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns)
: max_batch_size_(max_batch_size), max_queue_delay_ns_(max_queue_delay_ns)
{
}
size_t
InstanceQueue::Size()
{
return payload_queue_.size();
}
bool
InstanceQueue::Empty()
{
return payload_queue_.empty();
}
void
InstanceQueue::Enqueue(const std::shared_ptr<Payload>& payload)
{
payload_queue_.push_back(payload);
}
void
InstanceQueue::Dequeue(
std::shared_ptr<Payload>* payload,
std::vector<std::shared_ptr<Payload>>* merged_payloads)
{
*payload = payload_queue_.front();
payload_queue_.pop_front();
{
std::lock_guard<std::mutex> exec_lock(*((*payload)->GetExecMutex()));
(*payload)->SetState(Payload::State::EXECUTING);
if ((!payload_queue_.empty()) && (max_queue_delay_ns_ > 0) &&
(max_batch_size_ > 1) && (!(*payload)->IsSaturated())) {
bool continue_merge;
do {
continue_merge = false;
uint64_t now_ns =
std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::steady_clock::now().time_since_epoch())
.count();
size_t batch_size = (*payload)->BatchSize();
if ((!payload_queue_.empty()) &&
(!payload_queue_.front()->IsSaturated()) &&
(now_ns - payload_queue_.front()->BatcherStartNs()) >
max_queue_delay_ns_) {
std::lock_guard<std::mutex> exec_lock(
*(payload_queue_.front()->GetExecMutex()));
payload_queue_.front()->SetState(Payload::State::EXECUTING);
size_t front_batch_size = payload_queue_.front()->BatchSize();
if ((batch_size + front_batch_size) <= max_batch_size_) {
const auto& status =
(*payload)->MergePayload(payload_queue_.front());
if (status.IsOk()) {
merged_payloads->push_back(payload_queue_.front());
payload_queue_.pop_front();
continue_merge = true;
}
}
}
} while (continue_merge);
}
}
}
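// Illustrative example (not part of the original source): with
// max_batch_size_ = 8 and max_queue_delay_ns_ = 100000 (100us), dequeuing a
// payload of batch size 3 while the next payload (batch size 4) has waited
// longer than 100us merges the two, since 3 + 4 <= 8. A waiting payload of
// batch size 6 would be left in the queue because 3 + 6 exceeds the limit.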
}} // namespace triton::core
// Copyright 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "payload.h"
namespace triton { namespace core {
//
// InstanceQueue
//
// A queue implementation holding Payloads ready to be scheduled on a
// model instance.
class InstanceQueue {
public:
explicit InstanceQueue(size_t max_batch_size, uint64_t max_queue_delay_ns);
size_t Size();
bool Empty();
void Enqueue(const std::shared_ptr<Payload>& payload);
void Dequeue(
std::shared_ptr<Payload>* payload,
std::vector<std::shared_ptr<Payload>>* merged_payloads);
private:
size_t max_batch_size_;
uint64_t max_queue_delay_ns_;
std::deque<std::shared_ptr<Payload>> payload_queue_;
std::shared_ptr<Payload> staged_payload_;
std::mutex mu_;
};
}} // namespace triton::core