Commit c68e1835 authored by lijian6

Initial commit
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# A function that generates C++ sources and headers from proto files.
function(COMPILE_PROTO USE_GRPC PROTO_PATH OUT_PATH SRCS HDRS)
# Checking args.
if(NOT ARGN)
message(SEND_ERROR "Error: COMPILE_PROTO() called without any proto files")
return()
endif()
# To collect paths to created sources and headers.
set(${SRCS})
set(${HDRS})
# Getting actual absolute paths to all protos location and output directory.
get_filename_component(ABS_PROTO_PATH "${PROTO_PATH}" ABSOLUTE)
get_filename_component(ABS_OUT_PATH "${OUT_PATH}" ABSOLUTE)
# Launching sources generation for all proto files.
foreach(FIL ${ARGN})
# Getting the absolute path and filename without extension for the current proto file.
get_filename_component(ABS_FIL "${FIL}" ABSOLUTE)
get_filename_component(FIL_WE "${FIL}" NAME_WE)
# Getting the relative dir of the proto file (relative to the protos root dir).
file(RELATIVE_PATH REL_FIL_TO_PROTO "${ABS_PROTO_PATH}" "${ABS_FIL}")
get_filename_component(REL_DIR_TO_PROTO "${REL_FIL_TO_PROTO}" DIRECTORY)
# Preparing a path to label created sources from proto.
set(COMPILED_NAME_TEMPLATE "${ABS_OUT_PATH}/${REL_DIR_TO_PROTO}/${FIL_WE}")
# Firing sources generation command with gRPC application.
if(${USE_GRPC})
set(_GRPC_CPP_PLUGIN_EXECUTABLE $<TARGET_FILE:gRPC::grpc_cpp_plugin>)
# Marking created files for CMake.
list(APPEND ${SRCS} "${COMPILED_NAME_TEMPLATE}.grpc.pb.cc")
list(APPEND ${HDRS} "${COMPILED_NAME_TEMPLATE}.grpc.pb.h")
# Launching proto compilation command.
add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory "${ABS_OUT_PATH}"
OUTPUT
"${COMPILED_NAME_TEMPLATE}.grpc.pb.cc"
"${COMPILED_NAME_TEMPLATE}.grpc.pb.h"
COMMAND
${Protobuf_PROTOC_EXECUTABLE}
ARGS
--grpc_out=${ABS_OUT_PATH}
--plugin=protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE}
--proto_path=${ABS_PROTO_PATH}
${ABS_FIL}
DEPENDS
${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE}
COMMENT
"Running gRPC C++ protocol buffer compiler on ${FIL}"
VERBATIM)
# Without gRPC.
else()
list(APPEND ${SRCS} "${COMPILED_NAME_TEMPLATE}.pb.cc")
list(APPEND ${HDRS} "${COMPILED_NAME_TEMPLATE}.pb.h")
add_custom_command(
COMMAND ${CMAKE_COMMAND} -E make_directory "${ABS_OUT_PATH}"
OUTPUT
"${COMPILED_NAME_TEMPLATE}.pb.cc"
"${COMPILED_NAME_TEMPLATE}.pb.h"
COMMAND
${Protobuf_PROTOC_EXECUTABLE}
ARGS
--cpp_out=${ABS_OUT_PATH}
--proto_path=${ABS_PROTO_PATH}
${ABS_FIL}
DEPENDS
${ABS_FIL} ${Protobuf_PROTOC_EXECUTABLE}
COMMENT
"Running C++ protocol buffer compiler on ${FIL}"
VERBATIM)
endif()
endforeach()
# Returning generated sources list.
set_source_files_properties(${${SRCS}} ${${HDRS}} PROPERTIES GENERATED TRUE)
set(${SRCS} ${${SRCS}} PARENT_SCOPE)
set(${HDRS} ${${HDRS}} PARENT_SCOPE)
endfunction()
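# A hypothetical usage sketch (the proto path, output variables, and target
# below are illustrative and not part of this repository):
#
#   compile_proto(
#     TRUE                                   # USE_GRPC: also run the gRPC plugin
#     "${CMAKE_CURRENT_SOURCE_DIR}/protos"   # PROTO_PATH: protos root directory
#     "${CMAKE_CURRENT_BINARY_DIR}/protos"   # OUT_PATH: generated sources directory
#     PROTO_SRCS PROTO_HDRS                  # names of variables receiving outputs
#     protos/example_service.proto)
#   add_library(example-proto-objects OBJECT ${PROTO_SRCS} ${PROTO_HDRS})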
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "tfserve_client_backend.h"
#include "json_utils.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
//==============================================================================
Error
TFServeClientBackend::Create(
const std::string& url, const ProtocolType protocol,
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<Headers> http_headers, const bool verbose,
std::unique_ptr<ClientBackend>* client_backend)
{
if (protocol == ProtocolType::HTTP) {
return Error(
"perf_analyzer does not support http protocol with TF serving");
}
std::unique_ptr<TFServeClientBackend> tfserve_client_backend(
new TFServeClientBackend(compression_algorithm, http_headers));
RETURN_IF_CB_ERROR(GrpcClient::Create(
&(tfserve_client_backend->grpc_client_), url, verbose));
*client_backend = std::move(tfserve_client_backend);
return Error::Success;
}
Error
TFServeClientBackend::ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version)
{
tensorflow::serving::GetModelMetadataResponse metadata_proto;
RETURN_IF_CB_ERROR(grpc_client_->ModelMetadata(
&metadata_proto, model_name, model_version, *http_headers_));
std::string metadata;
::google::protobuf::util::JsonPrintOptions options;
options.preserve_proto_field_names = true;
options.always_print_primitive_fields = true;
::google::protobuf::util::MessageToJsonString(
metadata_proto, &metadata, options);
RETURN_IF_TRITON_ERROR(tc::ParseJson(model_metadata, metadata));
return Error::Success;
}
Error
TFServeClientBackend::Infer(
cb::InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
tfs::InferResult* tfserve_result;
RETURN_IF_CB_ERROR(grpc_client_->Infer(
&tfserve_result, options, inputs, outputs, *http_headers_,
compression_algorithm_));
*result = new TFServeInferResult(tfserve_result);
return Error::Success;
}
Error
TFServeClientBackend::AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
auto wrapped_callback = [callback](tfs::InferResult* client_result) {
cb::InferResult* result = new TFServeInferResult(client_result);
callback(result);
};
RETURN_IF_CB_ERROR(grpc_client_->AsyncInfer(
wrapped_callback, options, inputs, outputs, *http_headers_,
compression_algorithm_));
return Error::Success;
}
Error
TFServeClientBackend::ClientInferStat(InferStat* infer_stat)
{
// Reusing the common library utilities to collect and report the
// client side statistics.
tc::InferStat client_infer_stat;
RETURN_IF_TRITON_ERROR(grpc_client_->ClientInferStat(&client_infer_stat));
ParseInferStat(client_infer_stat, infer_stat);
return Error::Success;
}
void
TFServeClientBackend::ParseInferStat(
const tc::InferStat& tfserve_infer_stat, InferStat* infer_stat)
{
infer_stat->completed_request_count =
tfserve_infer_stat.completed_request_count;
infer_stat->cumulative_total_request_time_ns =
tfserve_infer_stat.cumulative_total_request_time_ns;
infer_stat->cumulative_send_time_ns =
tfserve_infer_stat.cumulative_send_time_ns;
infer_stat->cumulative_receive_time_ns =
tfserve_infer_stat.cumulative_receive_time_ns;
}
//==============================================================================
Error
TFServeInferRequestedOutput::Create(
InferRequestedOutput** infer_output, const std::string& name)
{
TFServeInferRequestedOutput* local_infer_output =
new TFServeInferRequestedOutput(name);
tc::InferRequestedOutput* tfserve_infer_output;
RETURN_IF_TRITON_ERROR(
tc::InferRequestedOutput::Create(&tfserve_infer_output, name));
local_infer_output->output_.reset(tfserve_infer_output);
*infer_output = local_infer_output;
return Error::Success;
}
TFServeInferRequestedOutput::TFServeInferRequestedOutput(
const std::string& name)
: InferRequestedOutput(BackendKind::TENSORFLOW_SERVING, name)
{
}
//==============================================================================
TFServeInferResult::TFServeInferResult(tfs::InferResult* result)
{
result_.reset(result);
}
Error
TFServeInferResult::Id(std::string* id) const
{
id->clear();
return Error::Success;
}
Error
TFServeInferResult::RequestStatus() const
{
RETURN_IF_CB_ERROR(result_->RequestStatus());
return Error::Success;
}
Error
TFServeInferResult::RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const
{
return Error(
"Output retrieval is not currently supported for TFS client backend");
}
//==============================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include "../../perf_utils.h"
#include "../client_backend.h"
#include "tfserve_grpc_client.h"
#define RETURN_IF_TRITON_ERROR(S) \
do { \
const tc::Error& status__ = (S); \
if (!status__.IsOk()) { \
return Error(status__.Message()); \
} \
} while (false)
namespace tc = triton::client;
namespace cb = triton::perfanalyzer::clientbackend;
namespace tfs = triton::perfanalyzer::clientbackend::tfserving;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
//==============================================================================
/// TFServeClientBackend is used to generate load on the TF serving instance
///
class TFServeClientBackend : public ClientBackend {
public:
/// Create a TFserving client backend which can be used to interact with the
/// server.
/// \param url The inference server url and port.
/// \param protocol The protocol type used.
/// \param compression_algorithm The compression algorithm to be used
/// on the grpc requests.
/// \param http_headers Map of HTTP headers. The map key/value indicates
/// the header name/value.
/// \param verbose Enables the verbose mode.
/// \param client_backend Returns a new TFServeClientBackend
/// object.
/// \return Error object indicating success or failure.
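/// A minimal creation sketch (illustrative only; the endpoint string is an
/// assumption and ProtocolType::GRPC is taken from the common backend enum):
/// \code
/// std::unique_ptr<ClientBackend> backend;
/// Error err = TFServeClientBackend::Create(
///     "localhost:8500", ProtocolType::GRPC, GRPC_COMPRESS_NONE,
///     std::make_shared<Headers>(), false /* verbose */, &backend);
/// \endcode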
static Error Create(
const std::string& url, const ProtocolType protocol,
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<Headers> http_headers, const bool verbose,
std::unique_ptr<ClientBackend>* client_backend);
/// See ClientBackend::ModelMetadata()
Error ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version) override;
/// See ClientBackend::Infer()
Error Infer(
cb::InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::AsyncInfer()
Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::ClientInferStat()
Error ClientInferStat(InferStat* infer_stat) override;
private:
TFServeClientBackend(
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<Headers> http_headers)
: ClientBackend(BackendKind::TENSORFLOW_SERVING),
compression_algorithm_(compression_algorithm),
http_headers_(http_headers)
{
}
void ParseInferStat(
const tc::InferStat& tfserve_infer_stat, InferStat* infer_stat);
std::unique_ptr<GrpcClient> grpc_client_;
grpc_compression_algorithm compression_algorithm_;
std::shared_ptr<Headers> http_headers_;
};
//==============================================================
/// TFServeInferRequestedOutput is a wrapper around the
/// InferRequestedOutput object of the Triton common client library.
///
class TFServeInferRequestedOutput : public InferRequestedOutput {
public:
static Error Create(
InferRequestedOutput** infer_output, const std::string& name);
/// Returns the raw InferRequestedOutput object required by TFserving client
/// library.
tc::InferRequestedOutput* Get() const { return output_.get(); }
private:
explicit TFServeInferRequestedOutput(const std::string& name);
std::unique_ptr<tc::InferRequestedOutput> output_;
};
//==============================================================
/// TFServeInferResult is a wrapper around the TF Serving
/// InferResult object.
///
class TFServeInferResult : public cb::InferResult {
public:
explicit TFServeInferResult(tfs::InferResult* result);
/// See InferResult::Id()
Error Id(std::string* id) const override;
/// See InferResult::RequestStatus()
Error RequestStatus() const override;
/// See InferResult::RawData()
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
private:
std::unique_ptr<tfs::InferResult> result_;
};
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "tfserve_grpc_client.h"
#include <chrono>
#include <cstdint>
#include <fstream>
#include <iostream>
#include <mutex>
#include <sstream>
#include "tfserve_client_backend.h"
/// Type alias for string-TensorProto map.
typedef google::protobuf::Map<std::string, tensorflow::TensorProto>
StringKeyedProtos;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
namespace {
// Use a map to keep track of gRPC channels. <key, value> : <url, Channel*>
// If a context is created for a url that already has an established Channel, reuse it.
std::map<std::string, std::shared_ptr<grpc::Channel>> grpc_channel_map_;
std::mutex grpc_channel_map_mtx_;
void
GetTensorFlowDataType(const std::string& datatype, tensorflow::DataType* dtype)
{
if (datatype == "FP16") {
*dtype = tensorflow::DataType::DT_HALF;
} else if (datatype == "BF16") {
*dtype = tensorflow::DataType::DT_BFLOAT16;
} else if (datatype == "FP32") {
*dtype = tensorflow::DataType::DT_FLOAT;
} else if (datatype == "FP64") {
*dtype = tensorflow::DataType::DT_DOUBLE;
} else if (datatype == "INT32") {
*dtype = tensorflow::DataType::DT_INT32;
} else if (datatype == "INT16") {
*dtype = tensorflow::DataType::DT_INT16;
} else if (datatype == "UINT16") {
*dtype = tensorflow::DataType::DT_UINT16;
} else if (datatype == "INT8") {
*dtype = tensorflow::DataType::DT_INT8;
} else if (datatype == "UINT8") {
*dtype = tensorflow::DataType::DT_UINT8;
} else if (datatype == "BYTES") {
*dtype = tensorflow::DataType::DT_STRING;
} else if (datatype == "INT64") {
*dtype = tensorflow::DataType::DT_INT64;
} else if (datatype == "BOOL") {
*dtype = tensorflow::DataType::DT_BOOL;
} else if (datatype == "UINT32") {
*dtype = tensorflow::DataType::DT_UINT32;
} else if (datatype == "UINT64") {
*dtype = tensorflow::DataType::DT_UINT64;
} else {
*dtype = tensorflow::DT_INVALID;
}
}
void
ReadFile(const std::string& filename, std::string& data)
{
data.clear();
if (!filename.empty()) {
std::ifstream file(filename.c_str(), std::ios::in);
if (file.is_open()) {
std::stringstream ss;
ss << file.rdbuf();
file.close();
data = ss.str();
}
}
}
std::shared_ptr<grpc::Channel>
GetChannel(const std::string& url, bool use_ssl, const SslOptions& ssl_options)
{
std::lock_guard<std::mutex> lock(grpc_channel_map_mtx_);
const auto& channel_itr = grpc_channel_map_.find(url);
if (channel_itr != grpc_channel_map_.end()) {
return channel_itr->second;
} else {
grpc::ChannelArguments arguments;
arguments.SetMaxSendMessageSize(tc::MAX_GRPC_MESSAGE_SIZE);
arguments.SetMaxReceiveMessageSize(tc::MAX_GRPC_MESSAGE_SIZE);
std::shared_ptr<grpc::ChannelCredentials> credentials;
if (use_ssl) {
std::string root;
std::string key;
std::string cert;
ReadFile(ssl_options.root_certificates, root);
ReadFile(ssl_options.private_key, key);
ReadFile(ssl_options.certificate_chain, cert);
grpc::SslCredentialsOptions opts = {root, key, cert};
credentials = grpc::SslCredentials(opts);
} else {
credentials = grpc::InsecureChannelCredentials();
}
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(url, credentials, arguments);
grpc_channel_map_.insert(std::make_pair(url, channel));
return channel;
}
}
} // namespace
//==============================================================================
// A GrpcInferRequest represents an in-flight inference request on gRPC.
//
class GrpcInferRequest {
public:
GrpcInferRequest(TFServeOnCompleteFn callback = nullptr)
: callback_(callback), grpc_status_(),
grpc_response_(std::make_shared<tensorflow::serving::PredictResponse>())
{
}
tc::RequestTimers& Timer() { return timer_; }
friend GrpcClient;
private:
TFServeOnCompleteFn callback_;
// Variables for GRPC call
grpc::ClientContext grpc_context_;
grpc::Status grpc_status_;
std::shared_ptr<tensorflow::serving::PredictResponse> grpc_response_;
// The timers for infer request.
tc::RequestTimers timer_;
};
//==============================================================================
Error
GrpcClient::Create(
std::unique_ptr<GrpcClient>* client, const std::string& server_url,
bool verbose, bool use_ssl, const SslOptions& ssl_options)
{
client->reset(new GrpcClient(server_url, verbose, use_ssl, ssl_options));
return Error::Success;
}
Error
GrpcClient::ModelMetadata(
tensorflow::serving::GetModelMetadataResponse* model_metadata,
const std::string& model_name, const std::string& model_version,
const Headers& headers)
{
model_metadata->Clear();
Error err;
tensorflow::serving::GetModelMetadataRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.mutable_model_spec()->set_name(model_name);
if (!model_version.empty()) {
request.mutable_model_spec()->set_version_label(model_version);
}
request.add_metadata_field("signature_def");
grpc::Status grpc_status =
stub_->GetModelMetadata(&context, request, model_metadata);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << model_metadata->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
GrpcClient::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers,
const grpc_compression_algorithm compression_algorithm)
{
Error err;
grpc::ClientContext context;
std::shared_ptr<GrpcInferRequest> sync_request(new GrpcInferRequest());
sync_request->Timer().Reset();
sync_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::REQUEST_START);
// Use send timer to measure time for marshalling infer request
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::SEND_START);
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
context.set_compression_algorithm(compression_algorithm);
err = PreRunProcessing(options, inputs, outputs);
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::SEND_END);
if (!err.IsOk()) {
return err;
}
sync_request->grpc_response_->Clear();
sync_request->grpc_status_ = stub_->Predict(
&context, infer_request_, sync_request->grpc_response_.get());
if (!sync_request->grpc_status_.ok()) {
err = Error(sync_request->grpc_status_.error_message());
}
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::RECV_START);
InferResult::Create(result, sync_request->grpc_response_, err);
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::RECV_END);
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::REQUEST_END);
tc::Error update_err = UpdateInferStat(sync_request->Timer());
if (!update_err.IsOk()) {
std::cerr << "Failed to update context stat: " << update_err << std::endl;
}
if (sync_request->grpc_status_.ok()) {
if (verbose_) {
std::cout << sync_request->grpc_response_->DebugString() << std::endl;
}
}
return (*result)->RequestStatus();
}
Error
GrpcClient::AsyncInfer(
TFServeOnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers,
const grpc_compression_algorithm compression_algorithm)
{
if (callback == nullptr) {
return Error(
"Callback function must be provided along with AsyncInfer() call.");
}
if (!worker_.joinable()) {
worker_ = std::thread(&GrpcClient::AsyncTransfer, this);
}
GrpcInferRequest* async_request;
async_request = new GrpcInferRequest(std::move(callback));
async_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::REQUEST_START);
async_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::SEND_START);
for (const auto& it : headers) {
async_request->grpc_context_.AddMetadata(it.first, it.second);
}
async_request->grpc_context_.set_compression_algorithm(compression_algorithm);
Error err = PreRunProcessing(options, inputs, outputs);
if (!err.IsOk()) {
delete async_request;
return err;
}
async_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::SEND_END);
std::unique_ptr<
grpc::ClientAsyncResponseReader<tensorflow::serving::PredictResponse>>
rpc(stub_->PrepareAsyncPredict(
&async_request->grpc_context_, infer_request_,
&async_request_completion_queue_));
rpc->StartCall();
rpc->Finish(
async_request->grpc_response_.get(), &async_request->grpc_status_,
(void*)async_request);
if (verbose_) {
std::cout << "Sent request";
if (options.request_id_.size() != 0) {
std::cout << " '" << options.request_id_ << "'";
}
std::cout << std::endl;
}
return Error::Success;
}
void
GrpcClient::AsyncTransfer()
{
while (!exiting_) {
// GRPC async APIs are thread-safe https://github.com/grpc/grpc/issues/4486
GrpcInferRequest* raw_async_request;
bool ok = true;
bool status =
async_request_completion_queue_.Next((void**)(&raw_async_request), &ok);
std::shared_ptr<GrpcInferRequest> async_request;
if (!ok) {
fprintf(stderr, "Unexpected not ok on client side.\n");
}
if (!status) {
if (!exiting_) {
fprintf(stderr, "Completion queue is closed.\n");
}
} else if (raw_async_request == nullptr) {
fprintf(stderr, "Unexpected null tag received at client.\n");
} else {
async_request.reset(raw_async_request);
InferResult* async_result;
Error err;
if (!async_request->grpc_status_.ok()) {
err = Error(async_request->grpc_status_.error_message());
}
async_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::RECV_START);
InferResult::Create(&async_result, async_request->grpc_response_, err);
async_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::RECV_END);
async_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::REQUEST_END);
tc::Error update_err = UpdateInferStat(async_request->Timer());
if (!update_err.IsOk()) {
std::cerr << "Failed to update context stat: " << update_err
<< std::endl;
}
if (async_request->grpc_status_.ok()) {
if (verbose_) {
std::cout << async_request->grpc_response_->DebugString()
<< std::endl;
}
}
async_request->callback_(async_result);
}
}
}
Error
GrpcClient::PreRunProcessing(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
// Populate the request protobuf
// Describing model name and signature from remote server.
infer_request_.mutable_model_spec()->set_name(options.model_name_);
if (!options.model_version_.empty()) {
infer_request_.mutable_model_spec()->set_version_label(
options.model_version_);
}
if (!options.model_signature_name_.empty()) {
infer_request_.mutable_model_spec()->set_signature_name(
options.model_signature_name_);
}
// Describing remote model inputs shape.
StringKeyedProtos& keyed_proto_inputs = *infer_request_.mutable_inputs();
std::set<std::string> request_inputs;
for (const auto input : inputs) {
auto raw_input = dynamic_cast<TFServeInferInput*>(input);
request_inputs.insert(raw_input->Name());
// Add new TensorProto submessages only if required, otherwise
// reuse the submessages already available.
auto itr = keyed_proto_inputs.find(raw_input->Name());
if (itr == keyed_proto_inputs.end()) {
itr = keyed_proto_inputs
.insert(google::protobuf::MapPair<
std::string, tensorflow::TensorProto>(
raw_input->Name(), tensorflow::TensorProto()))
.first;
}
// Set datatype
tensorflow::DataType tf_dtype = tensorflow::DT_INVALID;
GetTensorFlowDataType(raw_input->Datatype(), &tf_dtype);
itr->second.set_dtype(tf_dtype);
if (tf_dtype == tensorflow::DT_INVALID) {
return Error(
"failed to retrieve the TF datatype for " + raw_input->Name());
}
// Populate the shape
itr->second.mutable_tensor_shape()->Clear();
for (const auto dim : raw_input->Shape()) {
itr->second.mutable_tensor_shape()->add_dim()->set_size(dim);
}
raw_input->PrepareForRequest();
// There is an extra copy into the buffer to collect all the input
// batches. This is room for future improvement.
bool end_of_input = false;
// auto* raw_contents = itr->second.mutable_float_val()->mutable_data();
size_t content_size;
raw_input->ByteSize(&content_size);
temp_buffer_.clear();
temp_buffer_.reserve(content_size);
while (!end_of_input) {
const uint8_t* buf;
size_t buf_size;
raw_input->GetNext(&buf, &buf_size, &end_of_input);
if (buf != nullptr) {
temp_buffer_.append(reinterpret_cast<const char*>(buf), buf_size);
}
}
ClearAllInputFields(&itr->second);
PopulateInputData(raw_input, &itr->second);
}
// Remove extra tensor protos, if any.
std::set<std::string> extra_inputs;
for (const auto& iter : keyed_proto_inputs) {
if (request_inputs.find(iter.first) == request_inputs.end()) {
extra_inputs.insert(iter.first);
}
}
for (const auto& extra_input : extra_inputs) {
keyed_proto_inputs.erase(extra_input);
}
if (infer_request_.ByteSizeLong() > INT_MAX) {
size_t request_size = infer_request_.ByteSizeLong();
infer_request_.Clear();
return Error(
"Request has byte size " + std::to_string(request_size) +
" which exceed gRPC's byte size limit " + std::to_string(INT_MAX) +
".");
}
return Error::Success;
}
Error
GrpcClient::ClearAllInputFields(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_half_val()->Clear();
input_tensor_proto->mutable_float_val()->Clear();
input_tensor_proto->mutable_double_val()->Clear();
input_tensor_proto->mutable_int_val()->Clear();
input_tensor_proto->mutable_string_val()->Clear();
input_tensor_proto->mutable_int64_val()->Clear();
input_tensor_proto->mutable_bool_val()->Clear();
input_tensor_proto->mutable_uint32_val()->Clear();
input_tensor_proto->mutable_uint64_val()->Clear();
return Error::Success;
}
Error
GrpcClient::PopulateInputData(
TFServeInferInput* input, tensorflow::TensorProto* input_tensor_proto)
{
if (input->Datatype() == "FP16") {
RETURN_IF_CB_ERROR(PopulateHalfVal(input_tensor_proto));
} else if (input->Datatype() == "BF16") {
return Error(
"BF16 datatype not currently supported for populating input data.");
} else if (input->Datatype() == "FP32") {
RETURN_IF_CB_ERROR(PopulateFloatVal(input_tensor_proto));
} else if (input->Datatype() == "FP64") {
RETURN_IF_CB_ERROR(PopulateDoubleVal(input_tensor_proto));
} else if (input->Datatype() == "INT32") {
RETURN_IF_CB_ERROR(PopulateIntVal(input_tensor_proto));
} else if (input->Datatype() == "INT16") {
RETURN_IF_CB_ERROR(PopulateIntVal(input_tensor_proto, 2));
} else if (input->Datatype() == "UINT16") {
RETURN_IF_CB_ERROR(PopulateIntVal(input_tensor_proto, 2));
} else if (input->Datatype() == "INT8") {
RETURN_IF_CB_ERROR(PopulateIntVal(input_tensor_proto, 1));
} else if (input->Datatype() == "UINT8") {
RETURN_IF_CB_ERROR(PopulateIntVal(input_tensor_proto, 1));
} else if (input->Datatype() == "BYTES") {
RETURN_IF_CB_ERROR(PopulateStrVal(input_tensor_proto));
} else if (input->Datatype() == "INT64") {
RETURN_IF_CB_ERROR(PopulateInt64Val(input_tensor_proto));
} else if (input->Datatype() == "BOOL") {
RETURN_IF_CB_ERROR(PopulateBoolVal(input_tensor_proto));
} else if (input->Datatype() == "UINT32") {
RETURN_IF_CB_ERROR(PopulateUintVal(input_tensor_proto));
} else if (input->Datatype() == "UINT64") {
RETURN_IF_CB_ERROR(PopulateUint64Val(input_tensor_proto));
} else {
return Error("unsupported datatype for populating input data");
}
return Error::Success;
}
Error
GrpcClient::PopulateHalfVal(tensorflow::TensorProto* input_tensor_proto)
{
// Building FP16 one by one. Note that since protobuf has no int16 type, we'll
// have some pointless zero padding for each value here.
input_tensor_proto->mutable_half_val()->Reserve(2 * temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
int32_t elem = 0;  // zero-initialize; only the low 2 bytes are copied below
memcpy(&elem, (temp_buffer_.c_str() + copied_byte_size), 2);
input_tensor_proto->add_half_val(elem);
copied_byte_size += 2;
}
return Error::Success;
}
Error
GrpcClient::PopulateFloatVal(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_float_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_float_val(
*(float*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(float);
}
return Error::Success;
}
Error
GrpcClient::PopulateDoubleVal(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_double_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_double_val(
*(double*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(double);
}
return Error::Success;
}
Error
GrpcClient::PopulateIntVal(
tensorflow::TensorProto* input_tensor_proto, size_t step_size)
{
if (step_size == 4) {
input_tensor_proto->mutable_int_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_int_val(
*(int*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(int);
}
} else {
// Note that since protobuf has no int16/int8 type, we'll
// have some pointless zero padding for each value here and
// need to build the tensor one element at a time
input_tensor_proto->mutable_int_val()->Reserve(
temp_buffer_.size() * (4 / step_size));
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
int32_t elem = 0;  // zero-initialize; only 'step_size' bytes are copied below
memcpy(&elem, (temp_buffer_.c_str() + copied_byte_size), step_size);
input_tensor_proto->add_int_val(elem);
copied_byte_size += step_size;
}
}
return Error::Success;
}
Error
GrpcClient::PopulateStrVal(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_string_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
int32_t string_length = *((int*)(temp_buffer_.c_str() + copied_byte_size));
input_tensor_proto->add_string_val(std::string(
(temp_buffer_.c_str() + copied_byte_size + 4), string_length));
copied_byte_size += (string_length + 4);
}
return Error::Success;
}
Error
GrpcClient::PopulateBoolVal(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_bool_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_bool_val(
*(bool*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(bool);
}
return Error::Success;
}
Error
GrpcClient::PopulateInt64Val(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_int64_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_int64_val(
*(int64_t*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(int64_t);
}
return Error::Success;
}
Error
GrpcClient::PopulateUintVal(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_uint32_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_uint32_val(
*(uint32_t*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(uint32_t);
}
return Error::Success;
}
Error
GrpcClient::PopulateUint64Val(tensorflow::TensorProto* input_tensor_proto)
{
input_tensor_proto->mutable_uint64_val()->Reserve(temp_buffer_.size());
uint64_t copied_byte_size = 0;
while (copied_byte_size < temp_buffer_.size()) {
input_tensor_proto->add_uint64_val(
*(uint64_t*)(temp_buffer_.c_str() + copied_byte_size));
copied_byte_size += sizeof(uint64_t);
}
return Error::Success;
}
GrpcClient::GrpcClient(
const std::string& url, bool verbose, bool use_ssl,
const SslOptions& ssl_options)
: InferenceServerClient(verbose),
stub_(tensorflow::serving::PredictionService::NewStub(
GetChannel(url, use_ssl, ssl_options)))
{
}
GrpcClient::~GrpcClient()
{
exiting_ = true;
// Close the completion queue and wait for the worker thread to return.
async_request_completion_queue_.Shutdown();
// The thread is not joinable if AsyncInfer() was never called
// (it is a default-constructed thread before the first AsyncInfer() call).
if (worker_.joinable()) {
worker_.join();
}
bool has_next = true;
GrpcInferRequest* async_request;
bool ok;
do {
has_next =
async_request_completion_queue_.Next((void**)&async_request, &ok);
if (has_next && async_request != nullptr) {
delete async_request;
}
} while (has_next);
}
//======================================================================
Error
InferResult::Create(
InferResult** infer_result,
std::shared_ptr<tensorflow::serving::PredictResponse> response,
Error& request_status)
{
*infer_result =
reinterpret_cast<InferResult*>(new InferResult(response, request_status));
return Error::Success;
}
Error
InferResult::RequestStatus() const
{
return request_status_;
}
InferResult::InferResult(
std::shared_ptr<tensorflow::serving::PredictResponse> response,
Error& request_status)
: response_(response), request_status_(request_status)
{
}
//======================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <grpc++/grpc++.h>
#include "../client_backend.h"
#include "common.h"
#include "tensorflow_serving/apis/prediction_service.grpc.pb.h"
#include "tfserve_infer_input.h"
namespace tc = triton::client;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
struct SslOptions {
explicit SslOptions() {}
// File containing the PEM encoding of the server root certificates.
// If this parameter is empty, the default roots will be used. The
// default roots can be overridden using the
// GRPC_DEFAULT_SSL_ROOTS_FILE_PATH environment variable pointing
// to a file on the file system containing the roots.
std::string root_certificates;
// File containing the PEM encoding of the client's private key.
// This parameter can be empty if the client does not have a
// private key.
std::string private_key;
// File containing the PEM encoding of the client's certificate chain.
// This parameter can be empty if the client does not have a
// certificate chain.
std::string certificate_chain;
};
class InferResult;
using TFServeOnCompleteFn = std::function<void(InferResult*)>;
//==============================================================================
/// A GrpcClient object is used to perform any kind of communication with the
/// TF Serving service using the gRPC protocol. None of the functions are
/// thread safe.
///
/// \code
/// std::unique_ptr<GrpcClient> client;
/// GrpcClient::Create(&client, "localhost:8500");
/// ...
/// ...
/// \endcode
///
class GrpcClient : public tc::InferenceServerClient {
public:
~GrpcClient();
/// Create a client that can be used to communicate with the server.
/// \param client Returns a new GrpcClient object.
/// \param server_url The inference server name and port.
/// \param verbose If true generate verbose output when contacting
/// the inference server.
/// \param use_ssl If true use encrypted channel to the server.
/// \param ssl_options Specifies the files required for
/// SSL encryption and authorization.
/// \return Error object indicating success or failure.
static Error Create(
std::unique_ptr<GrpcClient>* client, const std::string& server_url,
bool verbose = false, bool use_ssl = false,
const SslOptions& ssl_options = SslOptions());
/// Contact the inference server and get the metadata of specified model.
/// \param model_metadata Returns model metadata as ModelMetadataResponse
/// message.
/// \param model_name The name of the model to get metadata.
/// \param model_version The version of the model to get metadata for.
/// The default value is an empty string, which means the server will
/// choose a version based on the model and its internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
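/// A minimal usage sketch (the model name below is illustrative):
/// \code
/// tensorflow::serving::GetModelMetadataResponse metadata;
/// Error err = client->ModelMetadata(&metadata, "my_model");
/// \endcode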
Error ModelMetadata(
tensorflow::serving::GetModelMetadataResponse* model_metadata,
const std::string& model_name, const std::string& model_version = "",
const Headers& headers = Headers());
/// Run synchronous inference on server.
/// \param result Returns the result of inference.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided, all the outputs in the model
/// config will be returned with default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// on the grpc requests.
/// \return Error object indicating success or failure of the
/// request.
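/// A minimal synchronous call sketch (illustrative; assumes 'options' and
/// 'inputs' are already populated). The caller takes ownership of the
/// returned result, mirroring how TFServeClientBackend wraps it:
/// \code
/// InferResult* result = nullptr;
/// Error err = client->Infer(&result, options, inputs);
/// std::unique_ptr<InferResult> owned(result);
/// \endcode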
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
const grpc_compression_algorithm compression_algorithm =
GRPC_COMPRESS_NONE);
/// Run asynchronous inference on server.
/// Once the request is completed, the InferResult pointer will be passed to
/// the provided 'callback' function. Upon invocation of the callback
/// function, the ownership of the InferResult object is transferred to the
/// function caller. It is then the caller's choice either to retrieve the
/// results inside the callback function or to defer that to a different
/// thread so that the client is unblocked. To prevent a memory leak, the
/// user must ensure this object gets deleted.
/// \param callback The callback function to be invoked on request completion.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided, all the outputs in the model
/// config will be returned with default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// on the grpc requests.
/// \return Error object indicating success or failure of the request.
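/// A minimal usage sketch (illustrative; the lambda body is an assumption).
/// The callback owns the passed InferResult and must delete it:
/// \code
/// client->AsyncInfer(
///     [](InferResult* result) {
///       std::unique_ptr<InferResult> owned(result);  // take ownership
///       // ... inspect owned->RequestStatus() before it is deleted ...
///     },
///     options, inputs);
/// \endcode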
Error AsyncInfer(
TFServeOnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
const grpc_compression_algorithm compression_algorithm =
GRPC_COMPRESS_NONE);
private:
GrpcClient(
const std::string& url, bool verbose, bool use_ssl,
const SslOptions& ssl_options);
Error PreRunProcessing(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
void AsyncTransfer();
Error ClearAllInputFields(tensorflow::TensorProto* input_tensor_proto);
Error PopulateInputData(
TFServeInferInput* input, tensorflow::TensorProto* input_tensor_proto);
Error PopulateHalfVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateFloatVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateDoubleVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateIntVal(
tensorflow::TensorProto* input_tensor_proto, size_t step_size = 4);
Error PopulateStrVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateBoolVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateInt64Val(tensorflow::TensorProto* input_tensor_proto);
Error PopulateUintVal(tensorflow::TensorProto* input_tensor_proto);
Error PopulateUint64Val(tensorflow::TensorProto* input_tensor_proto);
// The producer-consumer queue used to communicate asynchronously with
// the GRPC runtime.
grpc::CompletionQueue async_request_completion_queue_;
bool enable_stream_stats_;
std::mutex stream_mutex_;
// GRPC end point.
std::unique_ptr<tensorflow::serving::PredictionService::Stub> stub_;
// Request for the gRPC call. One request object can be used for multiple
// calls since it can be overwritten as soon as the gRPC send finishes.
tensorflow::serving::PredictRequest infer_request_;
// A temporary buffer to hold serialized data
std::string temp_buffer_;
};
//======================================================================
class InferResult {
public:
static Error Create(
InferResult** infer_result,
std::shared_ptr<tensorflow::serving::PredictResponse> response,
Error& request_status);
Error RequestStatus() const;
Error Id(std::string* id) const;
std::string DebugString() const { return response_->DebugString(); }
private:
InferResult(
std::shared_ptr<tensorflow::serving::PredictResponse> response,
Error& request_status);
std::shared_ptr<tensorflow::serving::PredictResponse> response_;
Error request_status_;
};
//======================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "tfserve_infer_input.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
Error
TFServeInferInput::Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype)
{
TFServeInferInput* local_infer_input =
new TFServeInferInput(name, dims, datatype);
*infer_input = local_infer_input;
return Error::Success;
}
Error
TFServeInferInput::SetShape(const std::vector<int64_t>& shape)
{
shape_ = shape;
return Error::Success;
}
Error
TFServeInferInput::Reset()
{
bufs_.clear();
buf_byte_sizes_.clear();
bufs_idx_ = 0;
byte_size_ = 0;
return Error::Success;
}
Error
TFServeInferInput::AppendRaw(const uint8_t* input, size_t input_byte_size)
{
byte_size_ += input_byte_size;
bufs_.push_back(input);
buf_byte_sizes_.push_back(input_byte_size);
return Error::Success;
}
Error
TFServeInferInput::ByteSize(size_t* byte_size) const
{
*byte_size = byte_size_;
return Error::Success;
}
Error
TFServeInferInput::PrepareForRequest()
{
// Reset position so request sends entire input.
bufs_idx_ = 0;
buf_pos_ = 0;
return Error::Success;
}
Error
TFServeInferInput::GetNext(
const uint8_t** buf, size_t* input_bytes, bool* end_of_input)
{
if (bufs_idx_ < bufs_.size()) {
*buf = bufs_[bufs_idx_];
*input_bytes = buf_byte_sizes_[bufs_idx_];
bufs_idx_++;
} else {
*buf = nullptr;
*input_bytes = 0;
}
*end_of_input = (bufs_idx_ >= bufs_.size());
return Error::Success;
}
TFServeInferInput::TFServeInferInput(
const std::string& name, const std::vector<int64_t>& dims,
const std::string& datatype)
: InferInput(BackendKind::TENSORFLOW_SERVING, name, datatype), shape_(dims)
{
}
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include "../../perf_utils.h"
#include "../client_backend.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tfserving {
//==============================================================
/// A TFServeInferInput instance holds the information regarding
/// a model input tensor and its corresponding generated data.
///
class TFServeInferInput : public InferInput {
public:
static Error Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype);
/// See InferInput::Shape()
const std::vector<int64_t>& Shape() const override { return shape_; }
/// See InferInput::SetShape()
Error SetShape(const std::vector<int64_t>& shape) override;
/// See InferInput::Reset()
Error Reset() override;
/// See InferInput::AppendRaw()
Error AppendRaw(const uint8_t* input, size_t input_byte_size) override;
/// Gets the size of data added into this input in bytes.
/// \param byte_size The size of data added in bytes.
/// \return Error object indicating success or failure.
Error ByteSize(size_t* byte_size) const;
/// Resets the read positions to start providing data from the beginning.
Error PrepareForRequest();
/// Get the next chunk of data if available.
Error GetNext(const uint8_t** buf, size_t* input_bytes, bool* end_of_input);
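/// A minimal consumption sketch (mirrors how GrpcClient drains an input;
/// the 'payload' string is illustrative):
/// \code
/// std::string payload;
/// input->PrepareForRequest();
/// bool end_of_input = false;
/// while (!end_of_input) {
///   const uint8_t* buf;
///   size_t buf_size;
///   input->GetNext(&buf, &buf_size, &end_of_input);
///   if (buf != nullptr) {
///     payload.append(reinterpret_cast<const char*>(buf), buf_size);
///   }
/// }
/// \endcode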
private:
explicit TFServeInferInput(
const std::string& name, const std::vector<int64_t>& dims,
const std::string& datatype);
std::vector<int64_t> shape_;
size_t byte_size_{0};
size_t bufs_idx_, buf_pos_;
std::vector<const uint8_t*> bufs_;
std::vector<size_t> buf_byte_sizes_;
};
}}}} // namespace triton::perfanalyzer::clientbackend::tfserving
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
set(
TS_CLIENT_BACKEND_SRCS
torchserve_client_backend.cc
torchserve_infer_input.cc
torchserve_http_client.cc
)
set(
TS_CLIENT_BACKEND_HDRS
torchserve_client_backend.h
torchserve_infer_input.h
torchserve_http_client.h
)
add_library(
ts-client-backend-library EXCLUDE_FROM_ALL OBJECT
${TS_CLIENT_BACKEND_SRCS}
${TS_CLIENT_BACKEND_HDRS}
)
target_link_libraries(
ts-client-backend-library
PUBLIC CURL::libcurl
PUBLIC httpclient_static
)
if(${TRITON_ENABLE_GPU})
target_include_directories(ts-client-backend-library PUBLIC ${CUDA_INCLUDE_DIRS})
target_link_libraries(ts-client-backend-library PRIVATE ${CUDA_LIBRARIES})
endif() # TRITON_ENABLE_GPU
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "torchserve_client_backend.h"
#include "json_utils.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
//==============================================================================
Error
TorchServeClientBackend::Create(
const std::string& url, const ProtocolType protocol,
std::shared_ptr<Headers> http_headers, const bool verbose,
std::unique_ptr<ClientBackend>* client_backend)
{
if (protocol == ProtocolType::GRPC) {
return Error(
"perf_analyzer does not support gRPC protocol with TorchServe");
}
std::unique_ptr<TorchServeClientBackend> torchserve_client_backend(
new TorchServeClientBackend(http_headers));
RETURN_IF_CB_ERROR(ts::HttpClient::Create(
&(torchserve_client_backend->http_client_), url, verbose));
*client_backend = std::move(torchserve_client_backend);
return Error::Success;
}
Error
TorchServeClientBackend::Infer(
cb::InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
ts::InferResult* torchserve_result;
RETURN_IF_CB_ERROR(http_client_->Infer(
&torchserve_result, options, inputs, outputs, *http_headers_));
*result = new TorchServeInferResult(torchserve_result);
return Error::Success;
}
Error
TorchServeClientBackend::ClientInferStat(InferStat* infer_stat)
{
// Reusing the common library utilities to collect and report the
// client side statistics.
tc::InferStat client_infer_stat;
RETURN_IF_TRITON_ERROR(http_client_->ClientInferStat(&client_infer_stat));
ParseInferStat(client_infer_stat, infer_stat);
return Error::Success;
}
void
TorchServeClientBackend::ParseInferStat(
const tc::InferStat& torchserve_infer_stat, InferStat* infer_stat)
{
infer_stat->completed_request_count =
torchserve_infer_stat.completed_request_count;
infer_stat->cumulative_total_request_time_ns =
torchserve_infer_stat.cumulative_total_request_time_ns;
infer_stat->cumulative_send_time_ns =
torchserve_infer_stat.cumulative_send_time_ns;
infer_stat->cumulative_receive_time_ns =
torchserve_infer_stat.cumulative_receive_time_ns;
}
//==============================================================================
TorchServeInferResult::TorchServeInferResult(ts::InferResult* result)
{
result_.reset(result);
}
Error
TorchServeInferResult::Id(std::string* id) const
{
id->clear();
return Error::Success;
}
Error
TorchServeInferResult::RequestStatus() const
{
RETURN_IF_CB_ERROR(result_->RequestStatus());
return Error::Success;
}
Error
TorchServeInferResult::RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const
{
return Error(
"Output retrieval is not currently supported for TorchServe client "
"backend");
}
//==============================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include "../../perf_utils.h"
#include "../client_backend.h"
#include "torchserve_http_client.h"
#define RETURN_IF_TRITON_ERROR(S) \
do { \
const tc::Error& status__ = (S); \
if (!status__.IsOk()) { \
return Error(status__.Message()); \
} \
} while (false)
namespace tc = triton::client;
namespace cb = triton::perfanalyzer::clientbackend;
namespace ts = triton::perfanalyzer::clientbackend::torchserve;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
//==============================================================================
/// TorchServeClientBackend is used to generate load on the TorchServe instance
///
class TorchServeClientBackend : public ClientBackend {
public:
/// Create a torchserve client backend which can be used to interact with the
/// server.
/// \param url The inference server url and port.
/// \param protocol The protocol type used.
/// \param http_headers Map of HTTP headers. The map key/value indicates
/// the header name/value.
/// \param verbose Enables the verbose mode.
/// \param client_backend Returns a new TorchServeClientBackend
/// object.
/// \return Error object indicating success or failure.
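  /// A minimal usage sketch; the endpoint address is illustrative and assumes
  /// TorchServe's default inference port:
  /// \code
  /// std::unique_ptr<ClientBackend> backend;
  /// Error err = TorchServeClientBackend::Create(
  ///     "localhost:8080", ProtocolType::HTTP,
  ///     std::make_shared<Headers>(), false /* verbose */, &backend);
  /// \endcode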
static Error Create(
const std::string& url, const ProtocolType protocol,
std::shared_ptr<Headers> http_headers, const bool verbose,
std::unique_ptr<ClientBackend>* client_backend);
/// See ClientBackend::Infer()
Error Infer(
cb::InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::ClientInferStat()
Error ClientInferStat(InferStat* infer_stat) override;
private:
TorchServeClientBackend(std::shared_ptr<Headers> http_headers)
: ClientBackend(BackendKind::TORCHSERVE), http_headers_(http_headers)
{
}
void ParseInferStat(
const tc::InferStat& torchserve_infer_stat, InferStat* infer_stat);
std::unique_ptr<ts::HttpClient> http_client_;
std::shared_ptr<Headers> http_headers_;
};
//==============================================================
/// TorchServeInferResult is a wrapper around the TorchServe-specific
/// InferResult object.
///
class TorchServeInferResult : public cb::InferResult {
public:
explicit TorchServeInferResult(ts::InferResult* result);
/// See InferResult::Id()
Error Id(std::string* id) const override;
/// See InferResult::RequestStatus()
Error RequestStatus() const override;
/// See InferResult::RawData()
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
private:
std::unique_ptr<ts::InferResult> result_;
};
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "torchserve_http_client.h"
#include <chrono>
#include <cstdint>
#include "torchserve_client_backend.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
namespace {
constexpr char kContentLengthHTTPHeader[] = "Content-Length";
//==============================================================================
// Global initialization for libcurl. Libcurl requires global
// initialization before any other threads are created and before any
// curl methods are used. The curl_global static object is used to
// perform this initialization.
class CurlGlobal {
public:
CurlGlobal();
~CurlGlobal();
const Error& Status() const { return err_; }
private:
Error err_;
};
CurlGlobal::CurlGlobal() : err_(Error::Success)
{
if (curl_global_init(CURL_GLOBAL_ALL) != 0) {
err_ = Error("global initialization failed");
}
}
CurlGlobal::~CurlGlobal()
{
curl_global_cleanup();
}
static CurlGlobal curl_global;
} // namespace
//==============================================================================
HttpInferRequest::HttpInferRequest()
: header_list_(nullptr),
file_ptr_(std::unique_ptr<FILE, Deleter>(nullptr, Deleter()))
{
}
HttpInferRequest::~HttpInferRequest()
{
if (header_list_ != nullptr) {
curl_slist_free_all(static_cast<curl_slist*>(header_list_));
header_list_ = nullptr;
}
}
Error
HttpInferRequest::InitializeRequest()
{
http_code_ = 400;
// Prepare buffer to record the response
infer_response_buffer_.reset(new std::string());
return Error::Success;
}
Error
HttpInferRequest::OpenFileData(std::string& file_path)
{
FILE* pFile = fopen(file_path.c_str(), "rb");
if (pFile == nullptr) {
return Error("Failed to open the specified file `" + file_path + "`");
}
file_ptr_.reset(pFile);
return Error::Success;
}
long
HttpInferRequest::FileSize()
{
long size;
fseek(file_ptr_.get(), 0, SEEK_END);
size = ftell(file_ptr_.get());
rewind(file_ptr_.get());
return size;
}
Error
HttpInferRequest::CloseFileData()
{
file_ptr_.reset(nullptr);
return Error::Success;
}
//==============================================================================
Error
HttpClient::Create(
std::unique_ptr<HttpClient>* client, const std::string& server_url,
bool verbose)
{
client->reset(new HttpClient(server_url, verbose));
return Error::Success;
}
Error
HttpClient::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers)
{
Error err;
std::string request_uri(url_ + "/predictions/" + options.model_name_);
if (!options.model_version_.empty()) {
request_uri += "/" + options.model_version_;
}
std::shared_ptr<HttpInferRequest> sync_request(new HttpInferRequest());
sync_request->Timer().Reset();
sync_request->Timer().CaptureTimestamp(
tc::RequestTimers::Kind::REQUEST_START);
if (!curl_global.Status().IsOk()) {
return curl_global.Status();
}
err = PreRunProcessing(
easy_handle_, request_uri, options, inputs, outputs, headers,
sync_request);
if (!err.IsOk()) {
return err;
}
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::SEND_START);
  // During this call SEND_END, RECV_START, and RECV_END will be captured
  // by the libcurl callbacks registered in PreRunProcessing().
auto curl_status = curl_easy_perform(easy_handle_);
if (curl_status != CURLE_OK) {
sync_request->http_code_ = 400;
} else {
curl_easy_getinfo(
easy_handle_, CURLINFO_RESPONSE_CODE, &sync_request->http_code_);
}
sync_request->CloseFileData();
curl_mime_free(mime_handle_);
InferResult::Create(result, sync_request);
sync_request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::REQUEST_END);
tc::Error nic_err = UpdateInferStat(sync_request->Timer());
if (!nic_err.IsOk()) {
std::cerr << "Failed to update context stat: " << nic_err << std::endl;
}
err = (*result)->RequestStatus();
return err;
}
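// Streams the contents of the file opened by OpenFileData() as the body of
// the mime part set up in PreRunProcessing(). SEND_END is captured once
// fread() reports no more data; SeekCallback() below lets libcurl rewind the
// file if it needs to resend the body.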
size_t
HttpClient::ReadCallback(char* buffer, size_t size, size_t nitems, void* userp)
{
size_t retcode =
fread(buffer, size, nitems, ((HttpInferRequest*)userp)->FilePtr());
if (retcode == 0) {
((HttpInferRequest*)userp)
->Timer()
.CaptureTimestamp(tc::RequestTimers::Kind::SEND_END);
}
return retcode;
}
int
HttpClient::SeekCallback(void* userp, curl_off_t offset, int origin)
{
if (fseek(((HttpInferRequest*)userp)->FilePtr(), offset, origin) == 0)
return CURL_SEEKFUNC_OK;
else
return CURL_SEEKFUNC_FAIL;
}
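// Inspects the response headers and, when a Content-Length header is present,
// reserves the response buffer up front so the body can be appended without
// repeated reallocations.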
size_t
HttpClient::InferResponseHeaderHandler(
void* contents, size_t size, size_t nmemb, void* userp)
{
HttpInferRequest* request = reinterpret_cast<HttpInferRequest*>(userp);
char* buf = reinterpret_cast<char*>(contents);
size_t byte_size = size * nmemb;
size_t idx = strlen(kContentLengthHTTPHeader);
if ((idx < byte_size) && !strncasecmp(buf, kContentLengthHTTPHeader, idx)) {
while ((idx < byte_size) && (buf[idx] != ':')) {
++idx;
}
if (idx < byte_size) {
std::string hdr(buf + idx + 1, byte_size - idx - 1);
request->infer_response_buffer_->reserve(std::stoi(hdr));
}
}
return byte_size;
}
size_t
HttpClient::InferResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp)
{
HttpInferRequest* request = reinterpret_cast<HttpInferRequest*>(userp);
if (request->Timer().Timestamp(tc::RequestTimers::Kind::RECV_START) == 0) {
request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::RECV_START);
}
char* buf = reinterpret_cast<char*>(contents);
size_t result_bytes = size * nmemb;
request->infer_response_buffer_->append(buf, result_bytes);
  // InferResponseHandler may be called multiple times, so RECV_END is
  // overwritten on every call and always records the time of the last chunk.
request->Timer().CaptureTimestamp(tc::RequestTimers::Kind::RECV_END);
return result_bytes;
}
Error
HttpClient::PreRunProcessing(
void* vcurl, std::string& request_uri, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, std::shared_ptr<HttpInferRequest>& http_request)
{
CURL* curl = reinterpret_cast<CURL*>(vcurl);
// Prepare the request object to provide the data for inference.
Error err = http_request->InitializeRequest();
if (!err.IsOk()) {
return err;
}
std::vector<std::string> input_filepaths;
curl_easy_setopt(curl, CURLOPT_URL, request_uri.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0");
curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1L);
if (verbose_) {
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
}
const long buffer_byte_size = 16 * 1024 * 1024;
curl_easy_setopt(curl, CURLOPT_UPLOAD_BUFFERSIZE, buffer_byte_size);
curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, buffer_byte_size);
  // The request body is supplied through the libcurl mime API set up below.
mime_handle_ = curl_mime_init(easy_handle_);
// Add the buffers holding input tensor data
for (const auto input : inputs) {
TorchServeInferInput* this_input =
dynamic_cast<TorchServeInferInput*>(input);
this_input->PrepareForRequest();
bool end_of_input = false;
while (!end_of_input) {
const uint8_t* buf;
size_t buf_size;
this_input->GetNext(&buf, &buf_size, &end_of_input);
      if (buf != nullptr) {
        // The buffer holds a serialized BYTES element: a 4-byte length
        // prefix followed by the file path, so skip the prefix here.
        std::string file_path(
            reinterpret_cast<const char*>(buf) + 4, buf_size - 4);
        Error err = http_request->OpenFileData(file_path);
        if (!err.IsOk()) {
          return err;
        }
        if (verbose_) {
          input_filepaths.push_back(file_path);
        }
      }
}
}
long file_size = http_request->FileSize();
curl_mimepart* part = curl_mime_addpart((curl_mime*)mime_handle_);
curl_mime_data_cb(
part, file_size, ReadCallback, SeekCallback, NULL, http_request.get());
curl_mime_name(part, "data");
curl_easy_setopt(easy_handle_, CURLOPT_MIMEPOST, (curl_mime*)mime_handle_);
// response headers handled by InferResponseHeaderHandler()
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, InferResponseHeaderHandler);
curl_easy_setopt(curl, CURLOPT_HEADERDATA, http_request.get());
// response data handled by InferResponseHandler()
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, InferResponseHandler);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, http_request.get());
struct curl_slist* list = nullptr;
for (const auto& pr : headers) {
std::string hdr = pr.first + ": " + pr.second;
list = curl_slist_append(list, hdr.c_str());
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list);
// The list will be freed when the request is destructed
http_request->header_list_ = list;
if (verbose_) {
std::cout << "inference request : [";
bool first = true;
for (const auto& fn : input_filepaths) {
if (first) {
first = false;
} else {
std::cout << ",";
}
std::cout << "\"" << fn << "\"";
}
std::cout << "]" << std::endl;
}
return Error::Success;
}
HttpClient::HttpClient(const std::string& url, bool verbose)
: InferenceServerClient(verbose), url_(url),
easy_handle_(reinterpret_cast<void*>(curl_easy_init()))
{
}
HttpClient::~HttpClient()
{
exiting_ = true;
if (easy_handle_ != nullptr) {
curl_easy_cleanup(reinterpret_cast<CURL*>(easy_handle_));
}
}
//======================================================================
Error
InferResult::Create(
InferResult** infer_result, std::shared_ptr<HttpInferRequest> infer_request)
{
*infer_result =
reinterpret_cast<InferResult*>(new InferResult(infer_request));
return Error::Success;
}
Error
InferResult::RequestStatus() const
{
return status_;
}
InferResult::InferResult(std::shared_ptr<HttpInferRequest> infer_request)
: infer_request_(infer_request)
{
if (infer_request->http_code_ != 200) {
status_ = Error(
"inference failed with error code " +
std::to_string(infer_request->http_code_));
}
}
//======================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <curl/curl.h>
#include <stdio.h>
#include <stdlib.h>
#include "../client_backend.h"
#include "common.h"
#include "torchserve_infer_input.h"
namespace tc = triton::client;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
class InferResult;
class HttpInferRequest;
using TorchServeOnCompleteFn = std::function<void(InferResult*)>;
//==============================================================================
/// An HttpClient object is used to perform any kind of communication with the
/// TorchServe service using libcurl. None of the functions are thread
/// safe.
///
/// \code
/// std::unique_ptr<HttpClient> client;
/// HttpClient::Create(&client, "localhost:8080");
/// ...
/// ...
/// \endcode
///
class HttpClient : public tc::InferenceServerClient {
public:
~HttpClient();
/// Create a client that can be used to communicate with the server.
  /// \param client Returns a new HttpClient object.
/// \param server_url The inference server name and port.
/// \param verbose If true generate verbose output when contacting
/// the inference server.
/// \return Error object indicating success or failure.
static Error Create(
std::unique_ptr<HttpClient>* client, const std::string& server_url,
const bool verbose);
/// Run synchronous inference on server.
/// \param result Returns the result of inference.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the model
/// config will be returned as default settings.
/// \param headers Optional map specifying additional HTTP headers to include
  /// in the HTTP request.
/// \return Error object indicating success or failure of the
/// request.
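  /// A minimal sketch of a synchronous call, assuming `options` and `inputs`
  /// have been prepared by the caller as described above:
  /// \code
  /// InferResult* result = nullptr;
  /// Error err = client->Infer(&result, options, inputs);
  /// if (err.IsOk()) {
  ///   std::cout << result->DebugString() << std::endl;
  /// }
  /// \endcode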
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers());
private:
HttpClient(const std::string& url, bool verbose);
Error PreRunProcessing(
void* curl, std::string& request_uri, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, std::shared_ptr<HttpInferRequest>& request);
static size_t ReadCallback(
char* buffer, size_t size, size_t nitems, void* userp);
static int SeekCallback(void* userp, curl_off_t offset, int origin);
static size_t InferResponseHeaderHandler(
void* contents, size_t size, size_t nmemb, void* userp);
static size_t InferResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp);
// The server url
const std::string url_;
// curl easy handle shared for all synchronous requests.
void* easy_handle_;
// The handle to interact with mime API.
curl_mime* mime_handle_;
};
//======================================================================
class HttpInferRequest {
public:
struct Deleter {
void operator()(FILE* file)
{
if (file != nullptr) {
fclose(file);
}
}
};
HttpInferRequest();
~HttpInferRequest();
Error InitializeRequest();
Error OpenFileData(std::string& file_path);
long FileSize();
Error CloseFileData();
tc::RequestTimers& Timer() { return timer_; }
std::string& DebugString() { return *infer_response_buffer_; }
FILE* FilePtr() { return file_ptr_.get(); }
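  // HttpClient and InferResult need direct access to the transfer state
  // below (header list, HTTP code, response buffer).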
friend HttpClient;
friend InferResult;
private:
  // Pointer to the list of HTTP request headers; kept so that it remains
  // valid for the duration of the transfer and can be freed once the
  // transfer is completed.
struct curl_slist* header_list_;
std::unique_ptr<FILE, Deleter> file_ptr_;
// HTTP response code for the inference request
long http_code_;
// Buffer that accumulates the response body.
std::unique_ptr<std::string> infer_response_buffer_;
// The timers for infer request.
tc::RequestTimers timer_;
};
//======================================================================
class InferResult {
public:
static Error Create(
InferResult** infer_result,
std::shared_ptr<HttpInferRequest> infer_request);
Error RequestStatus() const;
Error Id(std::string* id) const;
std::string DebugString() const { return infer_request_->DebugString(); }
private:
InferResult(std::shared_ptr<HttpInferRequest> infer_request);
// The status of the inference
Error status_;
// The pointer to the HttpInferRequest object
std::shared_ptr<HttpInferRequest> infer_request_;
};
//======================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "torchserve_infer_input.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
Error
TorchServeInferInput::Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype)
{
TorchServeInferInput* local_infer_input =
new TorchServeInferInput(name, dims, datatype);
*infer_input = local_infer_input;
return Error::Success;
}
Error
TorchServeInferInput::SetShape(const std::vector<int64_t>& shape)
{
shape_ = shape;
return Error::Success;
}
Error
TorchServeInferInput::Reset()
{
bufs_.clear();
buf_byte_sizes_.clear();
bufs_idx_ = 0;
byte_size_ = 0;
return Error::Success;
}
Error
TorchServeInferInput::AppendRaw(const uint8_t* input, size_t input_byte_size)
{
byte_size_ += input_byte_size;
bufs_.push_back(input);
buf_byte_sizes_.push_back(input_byte_size);
return Error::Success;
}
Error
TorchServeInferInput::ByteSize(size_t* byte_size) const
{
*byte_size = byte_size_;
return Error::Success;
}
Error
TorchServeInferInput::PrepareForRequest()
{
// Reset position so request sends entire input.
bufs_idx_ = 0;
buf_pos_ = 0;
return Error::Success;
}
Error
TorchServeInferInput::GetNext(
const uint8_t** buf, size_t* input_bytes, bool* end_of_input)
{
if (bufs_idx_ < bufs_.size()) {
*buf = bufs_[bufs_idx_];
*input_bytes = buf_byte_sizes_[bufs_idx_];
bufs_idx_++;
} else {
*buf = nullptr;
*input_bytes = 0;
}
*end_of_input = (bufs_idx_ >= bufs_.size());
return Error::Success;
}
TorchServeInferInput::TorchServeInferInput(
const std::string& name, const std::vector<int64_t>& dims,
const std::string& datatype)
: InferInput(BackendKind::TORCHSERVE, name, datatype), shape_(dims)
{
}
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <string>
#include "../../perf_utils.h"
#include "../client_backend.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace torchserve {
//==============================================================
/// A TorchServeInferInput instance holds the information regarding a
/// model input tensor. For this backend the content held is the path to
/// the file holding the data.
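/// A minimal sketch of constructing an input; the tensor name and the
/// `serialized_path` buffer (the serialized file path bytes) are illustrative:
/// \code
/// InferInput* input = nullptr;
/// TorchServeInferInput::Create(&input, "INPUT0", {1}, "BYTES");
/// input->AppendRaw(serialized_path.data(), serialized_path.size());
/// \endcode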
///
class TorchServeInferInput : public InferInput {
public:
static Error Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype);
/// See InferInput::Shape()
const std::vector<int64_t>& Shape() const override { return shape_; }
/// See InferInput::SetShape()
Error SetShape(const std::vector<int64_t>& shape) override;
/// See InferInput::Reset()
Error Reset() override;
/// See InferInput::AppendRaw()
Error AppendRaw(const uint8_t* input, size_t input_byte_size) override;
/// Gets the size of data added into this input in bytes.
/// \param byte_size The size of data added in bytes.
/// \return Error object indicating success or failure.
Error ByteSize(size_t* byte_size) const;
/// Resets the heads to start providing data from the beginning.
Error PrepareForRequest();
/// Get the next chunk of data if available.
Error GetNext(const uint8_t** buf, size_t* input_bytes, bool* end_of_input);
private:
explicit TorchServeInferInput(
const std::string& name, const std::vector<int64_t>& dims,
const std::string& datatype);
std::vector<int64_t> shape_;
  size_t byte_size_{0};
  size_t bufs_idx_{0}, buf_pos_{0};
std::vector<const uint8_t*> bufs_;
std::vector<size_t> buf_byte_sizes_;
};
}}}} // namespace triton::perfanalyzer::clientbackend::torchserve
# Copyright 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
set(
TRITON_CLIENT_BACKEND_SRCS
triton_client_backend.cc
)
set(
TRITON_CLIENT_BACKEND_HDRS
triton_client_backend.h
)
add_library(
triton-client-backend-library EXCLUDE_FROM_ALL OBJECT
${TRITON_CLIENT_BACKEND_SRCS}
${TRITON_CLIENT_BACKEND_HDRS}
)
target_link_libraries(
triton-client-backend-library
PUBLIC grpcclient_static
PUBLIC httpclient_static
PRIVATE CURL::libcurl
)
target_include_directories(
triton-client-backend-library
PRIVATE CURL::libcurl
)
if(${TRITON_ENABLE_GPU})
target_link_libraries(
triton-client-backend-library
PRIVATE CUDA::cudart
)
endif() # TRITON_ENABLE_GPU
// Copyright 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include <cstdint>
#include <map>
#include <string>
#include "../../doctest.h"
#include "triton_client_backend.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tritonremote {
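// Exposes TritonClientBackend::ParseAndStoreMetric so the doctest cases
// below can exercise it directly on canned metrics-endpoint text.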
class TestTritonClientBackend : public TritonClientBackend {
public:
template <typename T>
void ParseAndStoreMetric(
const std::string& metrics_endpoint_text, const std::string metric_id,
std::map<std::string, T>& metric_per_gpu)
{
TritonClientBackend::ParseAndStoreMetric<T>(
metrics_endpoint_text, metric_id, metric_per_gpu);
}
};
TEST_CASE("testing the ParseAndStoreMetric function")
{
TestTritonClientBackend ttcb{};
SUBCASE("nv_gpu_utilization metric")
{
const std::string metrics_endpoint_text{R"(
# HELP nv_gpu_utilization GPU utilization rate [0.0 - 1.0)
# TYPE nv_gpu_utilization gauge
nv_gpu_utilization{gpu_uuid="GPU-00000000-0000-0000-0000-000000000000"} 0.41
nv_gpu_utilization{gpu_uuid="GPU-00000000-0000-0000-0000-000000000001"} 0.77
)"};
const std::string metric_id{"nv_gpu_utilization"};
std::map<std::string, double> gpu_utilization_per_gpu{};
ttcb.ParseAndStoreMetric<double>(
metrics_endpoint_text, metric_id, gpu_utilization_per_gpu);
CHECK(gpu_utilization_per_gpu.size() == 2);
CHECK(
gpu_utilization_per_gpu["GPU-00000000-0000-0000-0000-000000000000"] ==
doctest::Approx(0.41));
CHECK(
gpu_utilization_per_gpu["GPU-00000000-0000-0000-0000-000000000001"] ==
doctest::Approx(0.77));
}
SUBCASE("nv_gpu_power_usage metric")
{
const std::string metrics_endpoint_text{R"(
# HELP nv_gpu_power_usage GPU power usage in watts
# TYPE nv_gpu_power_usage gauge
nv_gpu_power_usage{gpu_uuid="GPU-00000000-0000-0000-0000-000000000000"} 81.619
nv_gpu_power_usage{gpu_uuid="GPU-00000000-0000-0000-0000-000000000001"} 99.217
)"};
const std::string metric_id{"nv_gpu_power_usage"};
std::map<std::string, double> gpu_power_usage_per_gpu{};
ttcb.ParseAndStoreMetric<double>(
metrics_endpoint_text, metric_id, gpu_power_usage_per_gpu);
CHECK(gpu_power_usage_per_gpu.size() == 2);
CHECK(
gpu_power_usage_per_gpu["GPU-00000000-0000-0000-0000-000000000000"] ==
doctest::Approx(81.619));
CHECK(
gpu_power_usage_per_gpu["GPU-00000000-0000-0000-0000-000000000001"] ==
doctest::Approx(99.217));
}
SUBCASE("nv_gpu_memory_used_bytes metric")
{
const std::string metrics_endpoint_text{R"(
# HELP nv_gpu_memory_used_bytes GPU used memory, in bytes
# TYPE nv_gpu_memory_used_bytes gauge
nv_gpu_memory_used_bytes{gpu_uuid="GPU-00000000-0000-0000-0000-000000000000"} 50000000
nv_gpu_memory_used_bytes{gpu_uuid="GPU-00000000-0000-0000-0000-000000000001"} 75000000
)"};
const std::string metric_id{"nv_gpu_memory_used_bytes"};
std::map<std::string, uint64_t> gpu_memory_used_bytes_per_gpu{};
ttcb.ParseAndStoreMetric<uint64_t>(
metrics_endpoint_text, metric_id, gpu_memory_used_bytes_per_gpu);
CHECK(gpu_memory_used_bytes_per_gpu.size() == 2);
CHECK(
gpu_memory_used_bytes_per_gpu
["GPU-00000000-0000-0000-0000-000000000000"] == 50000000);
CHECK(
gpu_memory_used_bytes_per_gpu
["GPU-00000000-0000-0000-0000-000000000001"] == 75000000);
}
SUBCASE("nv_gpu_memory_total_bytes metric")
{
const std::string metrics_endpoint_text{R"(
# HELP nv_gpu_memory_total_bytes GPU total memory, in bytes
# TYPE nv_gpu_memory_total_bytes gauge
nv_gpu_memory_total_bytes{gpu_uuid="GPU-00000000-0000-0000-0000-000000000000"} 1000000000
nv_gpu_memory_total_bytes{gpu_uuid="GPU-00000000-0000-0000-0000-000000000001"} 2000000000
)"};
const std::string metric_id{"nv_gpu_memory_total_bytes"};
std::map<std::string, uint64_t> gpu_memory_total_bytes_per_gpu{};
ttcb.ParseAndStoreMetric<uint64_t>(
metrics_endpoint_text, metric_id, gpu_memory_total_bytes_per_gpu);
CHECK(gpu_memory_total_bytes_per_gpu.size() == 2);
CHECK(
gpu_memory_total_bytes_per_gpu
["GPU-00000000-0000-0000-0000-000000000000"] == 1000000000);
CHECK(
gpu_memory_total_bytes_per_gpu
["GPU-00000000-0000-0000-0000-000000000001"] == 2000000000);
}
}
}}}} // namespace triton::perfanalyzer::clientbackend::tritonremote
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "triton_client_backend.h"
#include <curl/curl.h>
#include <regex>
#include <stdexcept>
#include "../../constants.h"
#include "../../perf_analyzer_exception.h"
#include "json_utils.h"
namespace {
triton::client::HttpSslOptions
ParseHttpSslOptions(
const triton::perfanalyzer::clientbackend::SslOptionsBase& ssl_options)
{
triton::client::HttpSslOptions http_ssl_options;
http_ssl_options.verify_peer = ssl_options.ssl_https_verify_peer;
http_ssl_options.verify_host = ssl_options.ssl_https_verify_host;
http_ssl_options.ca_info = ssl_options.ssl_https_ca_certificates_file;
if (ssl_options.ssl_https_client_certificate_type == "PEM") {
http_ssl_options.cert_type =
triton::client::HttpSslOptions::CERTTYPE::CERT_PEM;
} else if (ssl_options.ssl_https_client_certificate_type == "DER") {
http_ssl_options.cert_type =
triton::client::HttpSslOptions::CERTTYPE::CERT_DER;
}
http_ssl_options.cert = ssl_options.ssl_https_client_certificate_file;
if (ssl_options.ssl_https_private_key_type == "PEM") {
http_ssl_options.key_type =
triton::client::HttpSslOptions::KEYTYPE::KEY_PEM;
} else if (ssl_options.ssl_https_private_key_type == "DER") {
http_ssl_options.key_type =
triton::client::HttpSslOptions::KEYTYPE::KEY_DER;
}
http_ssl_options.key = ssl_options.ssl_https_private_key_file;
return http_ssl_options;
}
std::pair<bool, triton::client::SslOptions>
ParseGrpcSslOptions(
const triton::perfanalyzer::clientbackend::SslOptionsBase& ssl_options)
{
bool use_ssl = ssl_options.ssl_grpc_use_ssl;
triton::client::SslOptions grpc_ssl_options;
grpc_ssl_options.root_certificates =
ssl_options.ssl_grpc_root_certifications_file;
grpc_ssl_options.private_key = ssl_options.ssl_grpc_private_key_file;
grpc_ssl_options.certificate_chain =
ssl_options.ssl_grpc_certificate_chain_file;
return std::pair<bool, triton::client::SslOptions>{use_ssl, grpc_ssl_options};
}
} // namespace
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tritonremote {
//==============================================================================
Error
TritonClientBackend::Create(
const std::string& url, const ProtocolType protocol,
const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<Headers> http_headers, const bool verbose,
const std::string& metrics_url, const TensorFormat input_tensor_format,
const TensorFormat output_tensor_format,
std::unique_ptr<ClientBackend>* client_backend)
{
std::unique_ptr<TritonClientBackend> triton_client_backend(
new TritonClientBackend(
protocol, compression_algorithm, http_headers, metrics_url,
input_tensor_format, output_tensor_format));
if (protocol == ProtocolType::HTTP) {
triton::client::HttpSslOptions http_ssl_options =
ParseHttpSslOptions(ssl_options);
RETURN_IF_TRITON_ERROR(tc::InferenceServerHttpClient::Create(
&(triton_client_backend->client_.http_client_), url, verbose,
http_ssl_options));
if (!trace_options.empty()) {
std::string response;
RETURN_IF_TRITON_ERROR(
triton_client_backend->client_.http_client_->UpdateTraceSettings(
&response, "", trace_options));
}
} else {
std::pair<bool, triton::client::SslOptions> grpc_ssl_options_pair =
ParseGrpcSslOptions(ssl_options);
bool use_ssl = grpc_ssl_options_pair.first;
triton::client::SslOptions grpc_ssl_options = grpc_ssl_options_pair.second;
RETURN_IF_TRITON_ERROR(tc::InferenceServerGrpcClient::Create(
&(triton_client_backend->client_.grpc_client_), url, verbose, use_ssl,
grpc_ssl_options));
if (!trace_options.empty()) {
inference::TraceSettingResponse response;
RETURN_IF_TRITON_ERROR(
triton_client_backend->client_.grpc_client_->UpdateTraceSettings(
&response, "", trace_options));
}
}
*client_backend = std::move(triton_client_backend);
return Error::Success;
}
Error
TritonClientBackend::ServerExtensions(std::set<std::string>* extensions)
{
extensions->clear();
if (protocol_ == ProtocolType::HTTP) {
std::string server_metadata;
FAIL_IF_TRITON_ERR(
client_.http_client_->ServerMetadata(&server_metadata, *http_headers_),
"unable to get server metadata");
rapidjson::Document server_metadata_json;
FAIL_IF_TRITON_ERR(
tc::ParseJson(&server_metadata_json, server_metadata),
"failed to parse server metadata");
for (const auto& extension :
server_metadata_json["extensions"].GetArray()) {
extensions->insert(
std::string(extension.GetString(), extension.GetStringLength()));
}
} else {
inference::ServerMetadataResponse server_metadata;
FAIL_IF_TRITON_ERR(
client_.grpc_client_->ServerMetadata(&server_metadata, *http_headers_),
"unable to get server metadata");
for (const auto& extension : server_metadata.extensions()) {
extensions->insert(extension);
}
}
return Error::Success;
}
Error
TritonClientBackend::ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version)
{
if (protocol_ == ProtocolType::HTTP) {
std::string metadata;
RETURN_IF_TRITON_ERROR(client_.http_client_->ModelMetadata(
&metadata, model_name, model_version, *http_headers_));
RETURN_IF_TRITON_ERROR(tc::ParseJson(model_metadata, metadata));
} else {
inference::ModelMetadataResponse model_metadata_proto;
RETURN_IF_TRITON_ERROR(client_.grpc_client_->ModelMetadata(
&model_metadata_proto, model_name, model_version, *http_headers_));
std::string metadata;
::google::protobuf::util::JsonPrintOptions options;
options.preserve_proto_field_names = true;
options.always_print_primitive_fields = true;
::google::protobuf::util::MessageToJsonString(
model_metadata_proto, &metadata, options);
RETURN_IF_TRITON_ERROR(tc::ParseJson(model_metadata, metadata));
}
return Error::Success;
}
Error
TritonClientBackend::ModelConfig(
rapidjson::Document* model_config, const std::string& model_name,
const std::string& model_version)
{
if (protocol_ == ProtocolType::HTTP) {
std::string config;
RETURN_IF_TRITON_ERROR(client_.http_client_->ModelConfig(
&config, model_name, model_version, *http_headers_));
RETURN_IF_TRITON_ERROR(tc::ParseJson(model_config, config));
} else {
inference::ModelConfigResponse model_config_proto;
RETURN_IF_TRITON_ERROR(client_.grpc_client_->ModelConfig(
&model_config_proto, model_name, model_version, *http_headers_));
std::string config;
::google::protobuf::util::JsonPrintOptions options;
options.preserve_proto_field_names = true;
options.always_print_primitive_fields = true;
::google::protobuf::util::MessageToJsonString(
model_config_proto, &config, options);
rapidjson::Document full_config;
RETURN_IF_TRITON_ERROR(tc::ParseJson(&full_config, config));
model_config->CopyFrom(full_config["config"], model_config->GetAllocator());
}
return Error::Success;
}
Error
TritonClientBackend::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
std::vector<tc::InferInput*> triton_inputs;
ParseInferInputToTriton(inputs, &triton_inputs);
std::vector<const tc::InferRequestedOutput*> triton_outputs;
ParseInferRequestedOutputToTriton(outputs, &triton_outputs);
tc::InferOptions triton_options(options.model_name_);
ParseInferOptionsToTriton(options, &triton_options);
tc::InferResult* triton_result;
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->Infer(
&triton_result, triton_options, triton_inputs, triton_outputs,
*http_headers_, compression_algorithm_));
} else {
RETURN_IF_TRITON_ERROR(client_.http_client_->Infer(
&triton_result, triton_options, triton_inputs, triton_outputs,
*http_headers_));
}
*result = new TritonInferResult(triton_result);
return Error::Success;
}
Error
TritonClientBackend::AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
auto wrapped_callback = [callback](tc::InferResult* client_result) {
InferResult* result = new TritonInferResult(client_result);
callback(result);
};
std::vector<tc::InferInput*> triton_inputs;
ParseInferInputToTriton(inputs, &triton_inputs);
std::vector<const tc::InferRequestedOutput*> triton_outputs;
ParseInferRequestedOutputToTriton(outputs, &triton_outputs);
tc::InferOptions triton_options(options.model_name_);
ParseInferOptionsToTriton(options, &triton_options);
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->AsyncInfer(
wrapped_callback, triton_options, triton_inputs, triton_outputs,
*http_headers_, compression_algorithm_));
} else {
RETURN_IF_TRITON_ERROR(client_.http_client_->AsyncInfer(
wrapped_callback, triton_options, triton_inputs, triton_outputs,
*http_headers_));
}
return Error::Success;
}
Error
TritonClientBackend::StartStream(OnCompleteFn callback, bool enable_stats)
{
auto wrapped_callback = [callback](tc::InferResult* client_result) {
InferResult* result = new TritonInferResult(client_result);
callback(result);
};
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->StartStream(
wrapped_callback, enable_stats, 0 /* stream_timeout */, *http_headers_,
compression_algorithm_));
} else {
return Error("HTTP does not support starting streams", pa::GENERIC_ERROR);
}
return Error::Success;
}
Error
TritonClientBackend::AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
std::vector<tc::InferInput*> triton_inputs;
ParseInferInputToTriton(inputs, &triton_inputs);
std::vector<const tc::InferRequestedOutput*> triton_outputs;
ParseInferRequestedOutputToTriton(outputs, &triton_outputs);
tc::InferOptions triton_options(options.model_name_);
ParseInferOptionsToTriton(options, &triton_options);
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->AsyncStreamInfer(
triton_options, triton_inputs, triton_outputs));
} else {
return Error(
"HTTP does not support streaming inferences", pa::GENERIC_ERROR);
}
return Error::Success;
}
Error
TritonClientBackend::ClientInferStat(InferStat* infer_stat)
{
tc::InferStat triton_infer_stat;
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(
client_.grpc_client_->ClientInferStat(&triton_infer_stat));
} else {
RETURN_IF_TRITON_ERROR(
client_.http_client_->ClientInferStat(&triton_infer_stat));
}
ParseInferStat(triton_infer_stat, infer_stat);
return Error::Success;
}
Error
TritonClientBackend::ModelInferenceStatistics(
std::map<ModelIdentifier, ModelStatistics>* model_stats,
const std::string& model_name, const std::string& model_version)
{
if (protocol_ == ProtocolType::GRPC) {
inference::ModelStatisticsResponse infer_stat;
RETURN_IF_TRITON_ERROR(client_.grpc_client_->ModelInferenceStatistics(
&infer_stat, model_name, model_version, *http_headers_));
ParseStatistics(infer_stat, model_stats);
} else {
std::string infer_stat;
RETURN_IF_TRITON_ERROR(client_.http_client_->ModelInferenceStatistics(
&infer_stat, model_name, model_version, *http_headers_));
rapidjson::Document infer_stat_json;
RETURN_IF_TRITON_ERROR(tc::ParseJson(&infer_stat_json, infer_stat));
ParseStatistics(infer_stat_json, model_stats);
}
return Error::Success;
}
Error
TritonClientBackend::Metrics(triton::perfanalyzer::Metrics& metrics)
{
try {
std::string metrics_endpoint_text{""};
AccessMetricsEndpoint(metrics_endpoint_text);
ParseAndStoreMetrics(metrics_endpoint_text, metrics);
}
catch (const PerfAnalyzerException& e) {
return Error(e.what(), pa::GENERIC_ERROR);
}
return Error::Success;
}
void
TritonClientBackend::AccessMetricsEndpoint(std::string& metrics_endpoint_text)
{
CURL* curl{curl_easy_init()};
if (curl == nullptr) {
throw triton::perfanalyzer::PerfAnalyzerException(
"Error calling curl_easy_init()", triton::perfanalyzer::GENERIC_ERROR);
}
const auto metrics_response_handler{
[](char* ptr, size_t size, size_t nmemb, std::string* userdata) {
userdata->append(ptr, size * nmemb);
return size * nmemb;
}};
curl_easy_setopt(curl, CURLOPT_URL, metrics_url_.c_str());
curl_easy_setopt(
curl, CURLOPT_WRITEFUNCTION,
static_cast<size_t (*)(char*, size_t, size_t, std::string*)>(
metrics_response_handler));
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &metrics_endpoint_text);
  CURLcode res{curl_easy_perform(curl)};
  if (res != CURLE_OK) {
    curl_easy_cleanup(curl);
    throw triton::perfanalyzer::PerfAnalyzerException(
        "Unable to connect to Metrics endpoint " + metrics_url_,
        triton::perfanalyzer::GENERIC_ERROR);
  }
  long response_code{0};
  curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &response_code);
  curl_easy_cleanup(curl);
  if (response_code != 200) {
    throw triton::perfanalyzer::PerfAnalyzerException(
        "Metrics endpoint request failed with HTTP status " +
            std::to_string(response_code),
        triton::perfanalyzer::GENERIC_ERROR);
  }
}
void
TritonClientBackend::ParseAndStoreMetrics(
const std::string& metrics_endpoint_text,
triton::perfanalyzer::Metrics& metrics)
{
ParseAndStoreMetric<double>(
metrics_endpoint_text, "nv_gpu_utilization",
metrics.gpu_utilization_per_gpu);
ParseAndStoreMetric<double>(
metrics_endpoint_text, "nv_gpu_power_usage",
metrics.gpu_power_usage_per_gpu);
ParseAndStoreMetric<uint64_t>(
metrics_endpoint_text, "nv_gpu_memory_used_bytes",
metrics.gpu_memory_used_bytes_per_gpu);
ParseAndStoreMetric<uint64_t>(
metrics_endpoint_text, "nv_gpu_memory_total_bytes",
metrics.gpu_memory_total_bytes_per_gpu);
}
Error
TritonClientBackend::UnregisterAllSharedMemory()
{
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(
client_.grpc_client_->UnregisterSystemSharedMemory("", *http_headers_));
RETURN_IF_TRITON_ERROR(
client_.grpc_client_->UnregisterCudaSharedMemory("", *http_headers_));
} else {
RETURN_IF_TRITON_ERROR(
client_.http_client_->UnregisterSystemSharedMemory("", *http_headers_));
RETURN_IF_TRITON_ERROR(
client_.http_client_->UnregisterCudaSharedMemory("", *http_headers_));
}
return Error::Success;
}
Error
TritonClientBackend::RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size)
{
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->RegisterSystemSharedMemory(
name, key, byte_size, 0 /* offset */, *http_headers_));
} else {
RETURN_IF_TRITON_ERROR(client_.http_client_->RegisterSystemSharedMemory(
name, key, byte_size, 0 /* offset */, *http_headers_));
}
return Error::Success;
}
Error
TritonClientBackend::RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& handle,
const size_t byte_size)
{
if (protocol_ == ProtocolType::GRPC) {
RETURN_IF_TRITON_ERROR(client_.grpc_client_->RegisterCudaSharedMemory(
name, handle, 0 /*device id*/, byte_size, *http_headers_));
} else {
RETURN_IF_TRITON_ERROR(client_.http_client_->RegisterCudaSharedMemory(
name, handle, 0 /*device id*/, byte_size, *http_headers_));
}
return Error::Success;
}
//
// Shared Memory Utilities
//
Error
TritonClientBackend::CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd)
{
RETURN_IF_TRITON_ERROR(
tc::CreateSharedMemoryRegion(shm_key, byte_size, shm_fd));
return Error::Success;
}
Error
TritonClientBackend::MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr)
{
RETURN_IF_TRITON_ERROR(
tc::MapSharedMemory(shm_fd, offset, byte_size, shm_addr));
return Error::Success;
}
Error
TritonClientBackend::CloseSharedMemory(int shm_fd)
{
RETURN_IF_TRITON_ERROR(tc::CloseSharedMemory(shm_fd));
return Error::Success;
}
Error
TritonClientBackend::UnlinkSharedMemoryRegion(std::string shm_key)
{
RETURN_IF_TRITON_ERROR(tc::UnlinkSharedMemoryRegion(shm_key));
return Error::Success;
}
Error
TritonClientBackend::UnmapSharedMemory(void* shm_addr, size_t byte_size)
{
RETURN_IF_TRITON_ERROR(tc::UnmapSharedMemory(shm_addr, byte_size));
return Error::Success;
}
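// Rough usage sketch, not a prescribed order: how the shared-memory helpers
// above and the register/unregister calls earlier in this file are typically
// combined (the region names and byte_size are placeholders):
//
//   int shm_fd;
//   void* shm_addr;
//   backend->CreateSharedMemoryRegion("/input_data", byte_size, &shm_fd);
//   backend->MapSharedMemory(shm_fd, 0 /* offset */, byte_size, &shm_addr);
//   backend->RegisterSystemSharedMemory("input_data", "/input_data", byte_size);
//   // ... copy tensor data into shm_addr and run inferences ...
//   backend->UnregisterAllSharedMemory();
//   backend->UnmapSharedMemory(shm_addr, byte_size);
//   backend->CloseSharedMemory(shm_fd);
//   backend->UnlinkSharedMemoryRegion("/input_data");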
void
TritonClientBackend::ParseInferInputToTriton(
const std::vector<InferInput*>& inputs,
std::vector<tc::InferInput*>* triton_inputs)
{
for (const auto input : inputs) {
tc::InferInput* triton_input{dynamic_cast<TritonInferInput*>(input)->Get()};
triton_input->SetBinaryData(input_tensor_format_ == TensorFormat::BINARY);
triton_inputs->push_back(triton_input);
}
}
void
TritonClientBackend::ParseInferRequestedOutputToTriton(
const std::vector<const InferRequestedOutput*>& outputs,
std::vector<const tc::InferRequestedOutput*>* triton_outputs)
{
for (const auto output : outputs) {
tc::InferRequestedOutput* triton_output{
dynamic_cast<const TritonInferRequestedOutput*>(output)->Get()};
    triton_output->SetBinaryData(output_tensor_format_ == TensorFormat::BINARY);
triton_outputs->push_back(triton_output);
}
}
void
TritonClientBackend::ParseInferOptionsToTriton(
const InferOptions& options, tc::InferOptions* triton_options)
{
triton_options->model_version_ = options.model_version_;
triton_options->request_id_ = options.request_id_;
if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "")) {
if (options.sequence_id_ != 0) {
triton_options->sequence_id_ = options.sequence_id_;
} else {
triton_options->sequence_id_str_ = options.sequence_id_str_;
}
triton_options->sequence_start_ = options.sequence_start_;
triton_options->sequence_end_ = options.sequence_end_;
}
triton_options->triton_enable_empty_final_response_ =
options.triton_enable_empty_final_response_;
}
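// Note: a request may carry either a numeric sequence id or a string sequence
// id; the block above forwards whichever one is set, preferring the numeric id
// when both are non-default, together with the sequence start/end flags.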
void
TritonClientBackend::ParseStatistics(
const inference::ModelStatisticsResponse& infer_stat,
std::map<ModelIdentifier, ModelStatistics>* model_stats)
{
model_stats->clear();
for (const auto& this_stat : infer_stat.model_stats()) {
auto it = model_stats
->emplace(
std::make_pair(this_stat.name(), this_stat.version()),
ModelStatistics())
.first;
it->second.inference_count_ = this_stat.inference_count();
it->second.execution_count_ = this_stat.execution_count();
it->second.success_count_ = this_stat.inference_stats().success().count();
it->second.queue_count_ = this_stat.inference_stats().queue().count();
it->second.compute_input_count_ =
this_stat.inference_stats().compute_input().count();
it->second.compute_infer_count_ =
this_stat.inference_stats().compute_infer().count();
it->second.compute_output_count_ =
this_stat.inference_stats().compute_output().count();
it->second.cumm_time_ns_ = this_stat.inference_stats().success().ns();
it->second.queue_time_ns_ = this_stat.inference_stats().queue().ns();
it->second.compute_input_time_ns_ =
this_stat.inference_stats().compute_input().ns();
it->second.compute_infer_time_ns_ =
this_stat.inference_stats().compute_infer().ns();
it->second.compute_output_time_ns_ =
this_stat.inference_stats().compute_output().ns();
it->second.cache_hit_count_ =
this_stat.inference_stats().cache_hit().count();
it->second.cache_hit_time_ns_ =
this_stat.inference_stats().cache_hit().ns();
it->second.cache_miss_count_ =
this_stat.inference_stats().cache_miss().count();
it->second.cache_miss_time_ns_ =
this_stat.inference_stats().cache_miss().ns();
}
}
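// The overload below derives the same ModelStatistics map from the HTTP/JSON
// form of the statistics response rather than the gRPC protobuf.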
void
TritonClientBackend::ParseStatistics(
const rapidjson::Document& infer_stat,
std::map<ModelIdentifier, ModelStatistics>* model_stats)
{
model_stats->clear();
for (const auto& this_stat : infer_stat["model_stats"].GetArray()) {
auto it = model_stats
->emplace(
std::make_pair(
this_stat["name"].GetString(),
this_stat["version"].GetString()),
ModelStatistics())
.first;
it->second.inference_count_ = this_stat["inference_count"].GetUint64();
it->second.execution_count_ = this_stat["execution_count"].GetUint64();
it->second.success_count_ =
this_stat["inference_stats"]["success"]["count"].GetUint64();
it->second.queue_count_ =
this_stat["inference_stats"]["queue"]["count"].GetUint64();
it->second.compute_input_count_ =
this_stat["inference_stats"]["compute_input"]["count"].GetUint64();
it->second.compute_infer_count_ =
this_stat["inference_stats"]["compute_infer"]["count"].GetUint64();
it->second.compute_output_count_ =
this_stat["inference_stats"]["compute_output"]["count"].GetUint64();
it->second.cumm_time_ns_ =
this_stat["inference_stats"]["success"]["ns"].GetUint64();
it->second.queue_time_ns_ =
this_stat["inference_stats"]["queue"]["ns"].GetUint64();
it->second.compute_input_time_ns_ =
this_stat["inference_stats"]["compute_input"]["ns"].GetUint64();
it->second.compute_infer_time_ns_ =
this_stat["inference_stats"]["compute_infer"]["ns"].GetUint64();
it->second.compute_output_time_ns_ =
this_stat["inference_stats"]["compute_output"]["ns"].GetUint64();
it->second.cache_hit_count_ =
this_stat["inference_stats"]["cache_hit"]["count"].GetUint64();
it->second.cache_hit_time_ns_ =
this_stat["inference_stats"]["cache_hit"]["ns"].GetUint64();
it->second.cache_miss_count_ =
this_stat["inference_stats"]["cache_miss"]["count"].GetUint64();
it->second.cache_miss_time_ns_ =
this_stat["inference_stats"]["cache_miss"]["ns"].GetUint64();
}
}
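// Illustrative sketch only: the JSON walked above is assumed to look roughly
// like the following (model name, version and values are made up):
//
//   {
//     "model_stats": [{
//       "name": "my_model", "version": "1",
//       "inference_count": 8, "execution_count": 4,
//       "inference_stats": {
//         "success": {"count": 8, "ns": 1200000},
//         "queue": {"count": 8, "ns": 300000},
//         "compute_input": {"count": 8, "ns": 100000},
//         "compute_infer": {"count": 8, "ns": 700000},
//         "compute_output": {"count": 8, "ns": 100000},
//         "cache_hit": {"count": 0, "ns": 0},
//         "cache_miss": {"count": 0, "ns": 0}
//       }
//     }]
//   }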
void
TritonClientBackend::ParseInferStat(
const tc::InferStat& triton_infer_stat, InferStat* infer_stat)
{
infer_stat->completed_request_count =
triton_infer_stat.completed_request_count;
infer_stat->cumulative_total_request_time_ns =
triton_infer_stat.cumulative_total_request_time_ns;
infer_stat->cumulative_send_time_ns =
triton_infer_stat.cumulative_send_time_ns;
infer_stat->cumulative_receive_time_ns =
triton_infer_stat.cumulative_receive_time_ns;
}
//==============================================================================
Error
TritonInferInput::Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype)
{
TritonInferInput* local_infer_input = new TritonInferInput(name, datatype);
tc::InferInput* triton_infer_input;
RETURN_IF_TRITON_ERROR(
tc::InferInput::Create(&triton_infer_input, name, dims, datatype));
local_infer_input->input_.reset(triton_infer_input);
*infer_input = local_infer_input;
return Error::Success;
}
const std::vector<int64_t>&
TritonInferInput::Shape() const
{
return input_->Shape();
}
Error
TritonInferInput::SetShape(const std::vector<int64_t>& shape)
{
RETURN_IF_TRITON_ERROR(input_->SetShape(shape));
return Error::Success;
}
Error
TritonInferInput::Reset()
{
RETURN_IF_TRITON_ERROR(input_->Reset());
return Error::Success;
}
Error
TritonInferInput::AppendRaw(const uint8_t* input, size_t input_byte_size)
{
RETURN_IF_TRITON_ERROR(input_->AppendRaw(input, input_byte_size));
return Error::Success;
}
Error
TritonInferInput::SetSharedMemory(
const std::string& name, size_t byte_size, size_t offset)
{
RETURN_IF_TRITON_ERROR(input_->SetSharedMemory(name, byte_size, offset));
return Error::Success;
}
TritonInferInput::TritonInferInput(
const std::string& name, const std::string& datatype)
: InferInput(BackendKind::TRITON, name, datatype)
{
}
//==============================================================================
Error
TritonInferRequestedOutput::Create(
InferRequestedOutput** infer_output, const std::string& name,
const size_t class_count)
{
TritonInferRequestedOutput* local_infer_output =
new TritonInferRequestedOutput(name);
tc::InferRequestedOutput* triton_infer_output;
RETURN_IF_TRITON_ERROR(tc::InferRequestedOutput::Create(
&triton_infer_output, name, class_count));
local_infer_output->output_.reset(triton_infer_output);
*infer_output = local_infer_output;
return Error::Success;
}
Error
TritonInferRequestedOutput::SetSharedMemory(
const std::string& region_name, const size_t byte_size, const size_t offset)
{
RETURN_IF_TRITON_ERROR(
output_->SetSharedMemory(region_name, byte_size, offset));
return Error::Success;
}
TritonInferRequestedOutput::TritonInferRequestedOutput(const std::string& name)
: InferRequestedOutput(BackendKind::TRITON, name)
{
}
//==============================================================================
TritonInferResult::TritonInferResult(tc::InferResult* result)
{
result_.reset(result);
}
Error
TritonInferResult::Id(std::string* id) const
{
RETURN_IF_TRITON_ERROR(result_->Id(id));
return Error::Success;
}
Error
TritonInferResult::RequestStatus() const
{
RETURN_IF_TRITON_ERROR(result_->RequestStatus());
return Error::Success;
}
Error
TritonInferResult::RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const
{
RETURN_IF_TRITON_ERROR(result_->RawData(output_name, buf, byte_size));
return Error::Success;
}
Error
TritonInferResult::IsFinalResponse(bool* is_final_response) const
{
RETURN_IF_TRITON_ERROR(result_->IsFinalResponse(is_final_response));
return Error::Success;
}
Error
TritonInferResult::IsNullResponse(bool* is_null_response) const
{
RETURN_IF_TRITON_ERROR(result_->IsNullResponse(is_null_response));
return Error::Success;
}
//==============================================================================
}}}} // namespace triton::perfanalyzer::clientbackend::tritonremote
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <cstdint>
#include <map>
#include <regex>
#include <string>
#include <type_traits>
#include "../../constants.h"
#include "../../metrics.h"
#include "../../perf_utils.h"
#include "../client_backend.h"
#include "grpc_client.h"
#include "http_client.h"
#include "shm_utils.h"
#define RETURN_IF_TRITON_ERROR(S) \
do { \
const tc::Error& status__ = (S); \
if (!status__.IsOk()) { \
return Error(status__.Message(), pa::GENERIC_ERROR); \
} \
} while (false)
#define FAIL_IF_TRITON_ERR(X, MSG) \
{ \
const tc::Error err = (X); \
if (!err.IsOk()) { \
std::cerr << "error: " << (MSG) << ": " << err << std::endl; \
exit(pa::GENERIC_ERROR); \
} \
}
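// Hedged usage sketch for the two macros above (client, live and url are
// hypothetical locals): RETURN_IF_TRITON_ERROR suits functions returning the
// backend Error type, FAIL_IF_TRITON_ERR suits paths that should abort.
//
//   Error CheckServer(tc::InferenceServerHttpClient* client)
//   {
//     bool live = false;
//     RETURN_IF_TRITON_ERROR(client->IsServerLive(&live));  // early return
//     return Error::Success;
//   }
//
//   std::unique_ptr<tc::InferenceServerHttpClient> client;
//   FAIL_IF_TRITON_ERR(
//       tc::InferenceServerHttpClient::Create(&client, url), "creating client");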
namespace tc = triton::client;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tritonremote {
#ifndef DOCTEST_CONFIG_DISABLE
class TestTritonClientBackend;
#endif
//==============================================================================
/// TritonClientBackend uses the Triton client C++ library to communicate with
/// the Triton inference service.
///
class TritonClientBackend : public ClientBackend {
public:
/// Create a triton client backend which can be used to interact with the
/// server.
/// \param url The inference server url and port.
/// \param protocol The protocol type used.
  /// \param ssl_options The SSL options used with the client backend.
  /// \param trace_options The trace options to be set on the server.
  /// \param compression_algorithm The compression algorithm to be used
  /// by the gRPC client.
/// \param http_headers Map of HTTP headers. The map key/value indicates
/// the header name/value.
/// \param verbose Enables the verbose mode.
/// \param metrics_url The inference server metrics url and port.
/// \param input_tensor_format The Triton inference request input tensor
/// format.
/// \param output_tensor_format The Triton inference response output tensor
/// format.
/// \param client_backend Returns a new TritonClientBackend object.
/// \return Error object indicating success or failure.
static Error Create(
const std::string& url, const ProtocolType protocol,
const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<tc::Headers> http_headers, const bool verbose,
const std::string& metrics_url,
const cb::TensorFormat input_tensor_format,
const cb::TensorFormat output_tensor_format,
std::unique_ptr<ClientBackend>* client_backend);
/// See ClientBackend::ServerExtensions()
Error ServerExtensions(std::set<std::string>* server_extensions) override;
/// See ClientBackend::ModelMetadata()
Error ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version) override;
/// See ClientBackend::ModelConfig()
Error ModelConfig(
rapidjson::Document* model_config, const std::string& model_name,
const std::string& model_version) override;
/// See ClientBackend::Infer()
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::AsyncInfer()
Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::StartStream()
Error StartStream(OnCompleteFn callback, bool enable_stats) override;
/// See ClientBackend::AsyncStreamInfer()
Error AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override;
/// See ClientBackend::ClientInferStat()
Error ClientInferStat(InferStat* infer_stat) override;
/// See ClientBackend::ModelInferenceStatistics()
Error ModelInferenceStatistics(
std::map<ModelIdentifier, ModelStatistics>* model_stats,
const std::string& model_name = "",
const std::string& model_version = "") override;
/// See ClientBackend::Metrics()
Error Metrics(triton::perfanalyzer::Metrics& metrics) override;
/// See ClientBackend::UnregisterAllSharedMemory()
Error UnregisterAllSharedMemory() override;
/// See ClientBackend::RegisterSystemSharedMemory()
Error RegisterSystemSharedMemory(
const std::string& name, const std::string& key,
const size_t byte_size) override;
/// See ClientBackend::RegisterCudaSharedMemory()
Error RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& handle,
const size_t byte_size) override;
/// See ClientBackend::CreateSharedMemoryRegion()
Error CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd) override;
/// See ClientBackend::MapSharedMemory()
Error MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr) override;
/// See ClientBackend::CloseSharedMemory()
Error CloseSharedMemory(int shm_fd) override;
/// See ClientBackend::UnlinkSharedMemoryRegion()
Error UnlinkSharedMemoryRegion(std::string shm_key) override;
/// See ClientBackend::UnmapSharedMemory()
Error UnmapSharedMemory(void* shm_addr, size_t byte_size) override;
private:
TritonClientBackend(
const ProtocolType protocol,
const grpc_compression_algorithm compression_algorithm,
std::shared_ptr<tc::Headers> http_headers, const std::string& metrics_url,
const cb::TensorFormat input_tensor_format,
const cb::TensorFormat output_tensor_format)
: ClientBackend(BackendKind::TRITON), protocol_(protocol),
compression_algorithm_(compression_algorithm),
http_headers_(http_headers), metrics_url_(metrics_url),
input_tensor_format_(input_tensor_format),
output_tensor_format_(output_tensor_format)
{
}
void ParseInferInputToTriton(
const std::vector<InferInput*>& inputs,
std::vector<tc::InferInput*>* triton_inputs);
void ParseInferRequestedOutputToTriton(
const std::vector<const InferRequestedOutput*>& outputs,
std::vector<const tc::InferRequestedOutput*>* triton_outputs);
void ParseInferOptionsToTriton(
const InferOptions& options, tc::InferOptions* triton_options);
void ParseStatistics(
const inference::ModelStatisticsResponse& infer_stat,
std::map<ModelIdentifier, ModelStatistics>* model_stats);
void ParseStatistics(
const rapidjson::Document& infer_stat,
std::map<ModelIdentifier, ModelStatistics>* model_stats);
void ParseInferStat(
const tc::InferStat& triton_infer_stat, InferStat* infer_stat);
void AccessMetricsEndpoint(std::string& metrics_endpoint_text);
void ParseAndStoreMetrics(
const std::string& metrics_endpoint_text,
triton::perfanalyzer::Metrics& metrics);
template <typename T>
void ParseAndStoreMetric(
const std::string& metrics_endpoint_text, const std::string metric_id,
std::map<std::string, T>& metric_per_gpu)
{
std::regex metric_regex(
R"(\n)" + metric_id + R"(\{gpu_uuid\=\"([^"]+)\"\} (\d+\.?\d*))");
std::sregex_iterator metric_regex_match_begin{std::sregex_iterator(
metrics_endpoint_text.begin(), metrics_endpoint_text.end(),
metric_regex)};
for (std::sregex_iterator i{metric_regex_match_begin};
i != std::sregex_iterator(); i++) {
const std::smatch& match{*i};
const std::string& gpu_uuid{match[1].str()};
T metric{};
if (std::is_same<T, double>::value) {
metric = std::stod(match[2].str());
} else if (std::is_same<T, uint64_t>::value) {
metric = static_cast<uint64_t>(std::stod(match[2].str()));
}
metric_per_gpu[gpu_uuid] = metric;
}
}
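  // Sketch of what the regex above is assumed to match (values made up):
  //
  //   "\nnv_gpu_utilization{gpu_uuid=\"GPU-1234\"} 0.62"
  //       -> metric_per_gpu["GPU-1234"] = 0.62
  //
  // Only plain decimal numbers are handled; a value emitted in scientific
  // notation would be misread because just the mantissa is captured.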
/// Union to represent the underlying triton client belonging to one of
  /// the protocols.
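  /// Only one member is live at a time: the constructor uses placement new to
  /// default-construct the HTTP member so the union starts in a defined state,
  /// and the factory is expected to install the client matching the selected
  /// protocol. Note that the empty destructor does not destroy the active
  /// member.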
union TritonClient {
TritonClient()
{
new (&http_client_) std::unique_ptr<tc::InferenceServerHttpClient>{};
}
~TritonClient() {}
std::unique_ptr<tc::InferenceServerHttpClient> http_client_;
std::unique_ptr<tc::InferenceServerGrpcClient> grpc_client_;
} client_;
const ProtocolType protocol_{UNKNOWN};
const grpc_compression_algorithm compression_algorithm_{GRPC_COMPRESS_NONE};
std::shared_ptr<tc::Headers> http_headers_;
const std::string metrics_url_{""};
const cb::TensorFormat input_tensor_format_{cb::TensorFormat::UNKNOWN};
const cb::TensorFormat output_tensor_format_{cb::TensorFormat::UNKNOWN};
#ifndef DOCTEST_CONFIG_DISABLE
friend TestTritonClientBackend;
public:
TritonClientBackend() = default;
#endif
};
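// Minimal, hypothetical call sketch against the interface above (argument
// values are placeholders; `options`, `inputs` and `outputs` are assumed to
// have been built elsewhere):
//
//   std::unique_ptr<ClientBackend> backend;
//   TritonClientBackend::Create(
//       "localhost:8001", ProtocolType::GRPC, SslOptionsBase(), {},
//       GRPC_COMPRESS_NONE, std::make_shared<tc::Headers>(), false /*verbose*/,
//       "localhost:8002/metrics", cb::TensorFormat::BINARY,
//       cb::TensorFormat::BINARY, &backend);
//
//   InferResult* result;
//   backend->Infer(&result, options, inputs, outputs);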
//==============================================================
/// TritonInferInput is a wrapper around the InferInput object of the
/// Triton client library.
///
class TritonInferInput : public InferInput {
public:
static Error Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype);
  /// Returns the raw InferInput object required by the Triton client library.
tc::InferInput* Get() const { return input_.get(); }
/// See InferInput::Shape()
const std::vector<int64_t>& Shape() const override;
/// See InferInput::SetShape()
Error SetShape(const std::vector<int64_t>& shape) override;
/// See InferInput::Reset()
Error Reset() override;
/// See InferInput::AppendRaw()
Error AppendRaw(const uint8_t* input, size_t input_byte_size) override;
/// See InferInput::SetSharedMemory()
Error SetSharedMemory(
const std::string& name, size_t byte_size, size_t offset = 0) override;
private:
explicit TritonInferInput(
const std::string& name, const std::string& datatype);
std::unique_ptr<tc::InferInput> input_;
};
//==============================================================
/// TritonInferRequestedOutput is a wrapper around the
/// InferRequestedOutput object of the Triton client library.
///
class TritonInferRequestedOutput : public InferRequestedOutput {
public:
static Error Create(
InferRequestedOutput** infer_output, const std::string& name,
const size_t class_count = 0);
  /// Returns the raw InferRequestedOutput object required by the Triton client
  /// library.
tc::InferRequestedOutput* Get() const { return output_.get(); }
// See InferRequestedOutput::SetSharedMemory()
Error SetSharedMemory(
const std::string& region_name, const size_t byte_size,
const size_t offset = 0) override;
private:
explicit TritonInferRequestedOutput(const std::string& name);
std::unique_ptr<tc::InferRequestedOutput> output_;
};
//==============================================================
/// TritonInferResult is a wrapper around the InferResult object of the
/// Triton client library.
///
class TritonInferResult : public InferResult {
public:
explicit TritonInferResult(tc::InferResult* result);
/// See InferResult::Id()
Error Id(std::string* id) const override;
/// See InferResult::RequestStatus()
Error RequestStatus() const override;
/// See InferResult::RawData()
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
/// See InferResult::IsFinalResponse()
Error IsFinalResponse(bool* is_final_response) const override;
/// See InferResult::IsNullResponse()
Error IsNullResponse(bool* is_null_response) const override;
private:
std::unique_ptr<tc::InferResult> result_;
};
}}}} // namespace triton::perfanalyzer::clientbackend::tritonremote
# Copyright 2021-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
set(
TRITON_C_API_CLIENT_BACKEND_SRCS
triton_c_api_backend.cc
shared_library.cc
triton_loader.cc
shared_memory_manager.cc
scoped_defer.cc
)
set(
TRITON_C_API_CLIENT_BACKEND_HDRS
triton_c_api_backend.h
shared_library.h
shared_memory_manager.h
triton_loader.h
c_api_infer_results.h
scoped_defer.h
)
add_library(
triton-c-api-backend-library EXCLUDE_FROM_ALL OBJECT
${TRITON_C_API_CLIENT_BACKEND_SRCS}
${TRITON_C_API_CLIENT_BACKEND_HDRS}
)
target_link_libraries(
triton-c-api-backend-library
grpcclient_static
httpclient_static
triton-core-serverapi # from repo-core
)
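# Note: because this is an EXCLUDE_FROM_ALL OBJECT library, it is only built
# when another target consumes it, e.g. (hypothetical consumer target):
#
#   target_link_libraries(client-backend-library PRIVATE triton-c-api-backend-library)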
// Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "common.h"
namespace tc = triton::client;
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tritoncapi {
/// This class is used to pass inference status and id to the upstream backend.
/// Created so that the API is similar to the `triton`, `torchserve`, and
/// `tensorflow_serving` client backend APIs.
class InferResult {
public:
static void Create(
InferResult** infer_result, const tc::Error& err, const std::string& id)
{
    *infer_result = new InferResult(err, id);
}
tc::Error Id(std::string* id) const
{
*id = request_id_;
return tc::Error::Success;
}
tc::Error RequestStatus() const { return status_; }
private:
InferResult(const tc::Error& err, const std::string& id)
: status_(err), request_id_(id)
{
}
std::string request_id_;
tc::Error status_;
};
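// Hypothetical usage sketch: the C API backend is expected to create one of
// these per request and hand it to the completion callback.
//
//   InferResult* result;
//   InferResult::Create(&result, tc::Error::Success, "request_0");
//   std::string id;
//   result->Id(&id);                       // id == "request_0"
//   bool ok = result->RequestStatus().IsOk();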
}}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi