// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
/// \file
#include <algorithm>
#include <chrono>
#include <condition_variable>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
namespace triton { namespace perfanalyzer { namespace clientbackend {
namespace tritoncapi {
class TritonLoader;
}}}} // namespace triton::perfanalyzer::clientbackend::tritoncapi
#endif
namespace triton { namespace client {
constexpr char kInferHeaderContentLengthHTTPHeader[] =
"Inference-Header-Content-Length";
constexpr int MAX_GRPC_MESSAGE_SIZE = INT32_MAX;
class InferResult;
class InferRequest;
class RequestTimers;
//==============================================================================
/// Error status reported by client API.
///
class Error {
public:
/// Create an error with the specified message.
/// \param msg The message for the error
explicit Error(const std::string& msg = "");
/// Accessor for the message of this error.
/// \return The message for the error. Empty if no error.
const std::string& Message() const { return msg_; }
/// Does this error indicate OK status?
/// \return True if this error indicates "ok"/"success", false if
/// error indicates a failure.
bool IsOk() const { return msg_.empty(); }
/// Convenience "success" value. Can be used as Error::Success to
/// indicate no error.
static const Error Success;
private:
friend std::ostream& operator<<(std::ostream&, const Error&);
std::string msg_;
};
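//
// Usage sketch (illustrative, not part of the API): every client call below
// returns an Error, and callers are expected to check IsOk() before using
// any output parameters; 'some_client_call' is a placeholder.
//
//   triton::client::Error err = some_client_call();
//   if (!err.IsOk()) {
//     std::cerr << "request failed: " << err << std::endl;
//   }
//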
//==============================================================================
/// Cumulative inference statistics.
///
/// \note
/// For the GRPC protocol, 'cumulative_send_time_ns' represents the
/// time for marshaling the infer request and
/// 'cumulative_receive_time_ns' represents the time for
/// unmarshaling the infer response.
struct InferStat {
/// Total number of requests completed.
size_t completed_request_count;
/// Time from the request start until the response is completely
/// received.
uint64_t cumulative_total_request_time_ns;
/// Time from the request start until the last byte is sent.
uint64_t cumulative_send_time_ns;
/// Time from receiving first byte of the response until the
/// response is completely received.
uint64_t cumulative_receive_time_ns;
/// Create a new InferStat object with zeroed statistics.
InferStat()
: completed_request_count(0), cumulative_total_request_time_ns(0),
cumulative_send_time_ns(0), cumulative_receive_time_ns(0)
{
}
};
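//
// Derived-metric sketch (illustrative): the cumulative counters can be
// turned into averages by the caller, e.g. the mean end-to-end latency in
// microseconds for an InferStat 'infer_stat' obtained from
// InferenceServerClient::ClientInferStat() is
//
//   double avg_request_us =
//       (infer_stat.completed_request_count == 0)
//           ? 0.0
//           : infer_stat.cumulative_total_request_time_ns /
//                 (1000.0 * infer_stat.completed_request_count);
//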
//==============================================================================
/// The base class for InferenceServerClients
///
class InferenceServerClient {
public:
using OnCompleteFn = std::function<void(InferResult*)>;
using OnMultiCompleteFn = std::function<void(std::vector<InferResult*>)>;
explicit InferenceServerClient(bool verbose)
: verbose_(verbose), exiting_(false)
{
}
virtual ~InferenceServerClient() = default;
/// Obtain the cumulative inference statistics of the client.
/// \param infer_stat Returns the InferStat object holding the current
/// statistics.
/// \return Error object indicating success or failure.
Error ClientInferStat(InferStat* infer_stat) const;
protected:
// Update the infer stat with the given timer
Error UpdateInferStat(const RequestTimers& timer);
// Enables verbose operation in the client.
bool verbose_;
// worker thread that will perform the asynchronous transfer
std::thread worker_;
// Avoid race condition between main thread and worker thread
std::mutex mutex_;
// Condition variable used for waiting on asynchronous request
std::condition_variable cv_;
// signal for worker thread to stop
bool exiting_;
// The inference statistic of the current client
InferStat infer_stat_;
};
//==============================================================================
/// Structure to hold options for Inference Request.
///
struct InferOptions {
explicit InferOptions(const std::string& model_name)
: model_name_(model_name), model_version_(""), request_id_(""),
sequence_id_(0), sequence_id_str_(""), sequence_start_(false),
sequence_end_(false), priority_(0), server_timeout_(0),
client_timeout_(0), triton_enable_empty_final_response_(false)
{
}
/// The name of the model to run inference.
std::string model_name_;
/// The version of the model to use while running inference. The default
/// value is an empty string which means the server will select the
/// version of the model based on its internal policy.
std::string model_version_;
/// An identifier for the request. If specified will be returned
/// in the response. Default value is an empty string which means no
/// request_id will be used.
std::string request_id_;
/// The unique identifier for the sequence being represented by the
/// object. Default value is 0 which means that the request does not
/// belong to a sequence. If this value is non-zero, then sequence_id_str_
/// MUST be set to "".
uint64_t sequence_id_;
/// The unique identifier for the sequence being represented by the
/// object. Default value is "" which means that the request does not
/// belong to a sequence. If this value is non-empty, then sequence_id_
/// MUST be set to 0.
std::string sequence_id_str_;
/// Indicates whether the request being added marks the start of the
/// sequence. Default value is False. This argument is ignored if
/// 'sequence_id' is 0.
bool sequence_start_;
/// Indicates whether the request being added marks the end of the
/// sequence. Default value is False. This argument is ignored if
/// 'sequence_id' is 0.
bool sequence_end_;
/// Indicates the priority of the request. Priority value zero
/// indicates that the default priority level should be used
/// (i.e. same behavior as not specifying the priority parameter).
/// Lower value priorities indicate higher priority levels. Thus
/// the highest priority level is indicated by setting the parameter
/// to 1, the next highest is 2, etc. If not provided, the server
/// will handle the request using default setting for the model.
uint64_t priority_;
/// The timeout value for the request, in microseconds. If the request
/// cannot be completed within this time, the server can take a
/// model-specific action such as terminating the request. If not
/// provided, the server will handle the request using default setting
/// for the model. This option is only respected by the model that is
/// configured with dynamic batching. See here for more details:
/// https://github.com/triton-inference-server/server/blob/main/docs/user_guide/model_configuration.md#dynamic-batcher
uint64_t server_timeout_;
/// The maximum end-to-end time, in microseconds, the request is allowed
/// to take. The client will abort the request when the specified time
/// elapses. The request will return an error with message "Deadline
/// Exceeded". The default value is 0 which means the client will wait for
/// the response from the server. This option is not supported for
/// streaming requests. Instead see the 'stream_timeout' argument in
/// InferenceServerGrpcClient::StartStream().
/// NOTE: the HTTP client library only offers millisecond precision, so a
/// timeout < 1000 microseconds will be rounded down to 0 milliseconds and
/// have no effect.
uint64_t client_timeout_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
};
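//
// Construction sketch (illustrative; "simple_model" is a placeholder name):
//
//   triton::client::InferOptions options("simple_model");
//   options.model_version_ = "1";       // pin a specific model version
//   options.request_id_ = "req-0";      // echoed back in the response
//   options.client_timeout_ = 500000;   // 500 ms end-to-end timeout
//
// All other members keep the defaults set by the constructor above.
//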
//==============================================================================
/// An interface for InferInput object to describe the model input for
/// inference.
///
class InferInput {
public:
/// Create an InferInput instance that describes a model input.
/// \param infer_input Returns a new InferInput object.
/// \param name The name of the input whose data will be described by this
/// object.
/// \param dims The shape of the input.
/// \param datatype The datatype of the input.
/// \return Error object indicating success or failure.
static Error Create(
InferInput** infer_input, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype);
/// Gets name of the associated input tensor.
/// \return The name of the tensor.
const std::string& Name() const { return name_; }
/// Gets datatype of the associated input tensor.
/// \return The datatype of the tensor.
const std::string& Datatype() const { return datatype_; }
/// Gets the shape of the input tensor.
/// \return The shape of the tensor.
const std::vector<int64_t>& Shape() const { return shape_; }
/// Set the shape of input associated with this object.
/// \param dims the vector of dims representing the new shape
/// of input.
/// \return Error object indicating success or failure of the
/// request.
Error SetShape(const std::vector<int64_t>& dims);
/// Prepare this input to receive new tensor values. Forget any
/// existing values that were set by previous calls to SetSharedMemory()
/// or AppendRaw().
/// \return Error object indicating success or failure.
Error Reset();
/// Append tensor values for this input from a byte vector. The vector
/// is not copied and so it must not be modified or destroyed
/// until this input is no longer needed (that is until the Infer()
/// call(s) that use the input have completed). Multiple calls can
/// be made to this API to keep adding tensor data for this input.
/// The data will be delivered in the order it was added.
/// \param input The vector holding tensor values.
/// \return Error object indicating success or failure.
Error AppendRaw(const std::vector<uint8_t>& input);
/// Append tensor values for this input from a byte array. The array
/// is not copied and so it must not be modified or destroyed
/// until this input is no longer needed (that is until the Infer()
/// call(s) that use the input have completed). Multiple calls can
/// be made to this API to keep adding tensor data for this input.
/// The data will be delivered in the order it was added.
/// \param input The pointer to the array holding the tensor value.
/// \param input_byte_size The size of the array in bytes.
/// \return Error object indicating success or failure.
Error AppendRaw(const uint8_t* input, size_t input_byte_size);
/// Set tensor values for this input by reference into a shared memory
/// region. The values are not copied and so the shared memory region and
/// its contents must not be modified or destroyed until this input is no
/// longer needed (that is until the Infer() call(s) that use the input have
/// completed). This function must be called a single time for an input that
/// is using shared memory. The entire tensor data required by this input
/// must be contiguous in a single shared memory region.
/// \param name The user-given name for the registered shared memory region
/// where the tensor values for this input is stored.
/// \param byte_size The size, in bytes of the input tensor data. Must
/// match the size expected for the input shape.
/// \param offset The offset into the shared memory region at which the
/// input tensor values start. The default value is 0.
/// \return Error object indicating success or failure
Error SetSharedMemory(
const std::string& name, size_t byte_size, size_t offset = 0);
/// \return true if this input is being provided in shared memory.
bool IsSharedMemory() const { return (io_type_ == SHARED_MEMORY); }
/// Get information about the shared memory being used for this
/// input.
/// \param name Returns the name of the shared memory region.
/// \param byte_size Returns the size, in bytes, of the shared
/// memory region.
/// \param offset Returns the offset within the shared memory
/// region.
/// \return Error object indicating success or failure.
Error SharedMemoryInfo(
std::string* name, size_t* byte_size, size_t* offset) const;
/// Append tensor values for this input from a vector of
/// strings. This method can only be used for tensors with BYTES
/// data-type. The strings are assigned in row-major order to the
/// elements of the tensor. The strings are copied and so the
/// 'input' does not need to be preserved as with AppendRaw(). Multiple
/// calls can be made to this API to keep adding tensor data for
/// this input. The data will be delivered in the order it was added.
/// \param input The vector holding tensor string values.
/// \return Error object indicating success or failure.
Error AppendFromString(const std::vector<std::string>& input);
/// Gets the size of data added into this input in bytes.
/// \param byte_size The size of data added in bytes.
/// \return Error object indicating success or failure.
Error ByteSize(size_t* byte_size) const;
/// \return true if this input should be sent in binary format.
bool BinaryData() const { return binary_data_; }
/// \return Error object indicating success or failure.
Error SetBinaryData(const bool binary_data);
private:
#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
friend class TRITON_INFERENCE_SERVER_CLIENT_CLASS;
#endif
friend class HttpInferRequest;
InferInput(
const std::string& name, const std::vector<int64_t>& dims,
const std::string& datatype);
Error PrepareForRequest();
Error GetNext(
uint8_t* buf, size_t size, size_t* input_bytes, bool* end_of_input);
Error GetNext(const uint8_t** buf, size_t* input_bytes, bool* end_of_input);
std::string name_;
std::vector<int64_t> shape_;
std::string datatype_;
size_t byte_size_;
size_t bufs_idx_, buf_pos_;
std::vector<const uint8_t*> bufs_;
std::vector<size_t> buf_byte_sizes_;
// Used only for STRING type tensors set with AppendFromString(). Hold
// the "raw" serialization of the string values for each index
// that are then referenced by 'bufs_'. A std::list is used to avoid
// reallocs that could invalidate the pointer references into the
// std::string objects.
std::list<std::string> str_bufs_;
// Used only if working with Shared Memory
enum IOType { NONE, RAW, SHARED_MEMORY };
IOType io_type_;
std::string shm_name_;
size_t shm_offset_;
bool binary_data_{true};
};
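//
// Usage sketch (illustrative; the tensor name, shape and data are
// placeholders):
//
//   triton::client::InferInput* input;
//   std::vector<int64_t> shape{1, 4};
//   triton::client::Error err = triton::client::InferInput::Create(
//       &input, "INPUT0", shape, "FP32");
//   std::vector<float> data{1.f, 2.f, 3.f, 4.f};
//   err = input->AppendRaw(
//       reinterpret_cast<const uint8_t*>(data.data()),
//       data.size() * sizeof(float));
//
// As documented for AppendRaw() above, 'data' must stay alive until the
// Infer() call that uses this input has completed.
//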
//==============================================================================
/// An InferRequestedOutput object is used to describe the requested model
/// output for inference.
///
class InferRequestedOutput {
public:
/// Create an InferRequestedOutput instance that describes a model output being
/// requested.
/// \param infer_output Returns a new InferRequestedOutput object.
/// \param name The name of output being requested.
/// \param class_count The number of classifications to be requested. The
/// default value is 0 which means the classification results are not
/// requested.
/// \return Error object indicating success or failure.
static Error Create(
InferRequestedOutput** infer_output, const std::string& name,
const size_t class_count = 0);
/// Gets name of the associated output tensor.
/// \return The name of the tensor.
const std::string& Name() const { return name_; }
/// Get the number of classifications requested for this output, or
/// 0 if the output is not being returned as classifications.
size_t ClassificationCount() const { return class_count_; }
/// Set the output tensor data to be written to specified shared
/// memory region.
/// \param region_name The name of the shared memory region.
/// \param byte_size The size of data in bytes.
/// \param offset The offset in shared memory region. Default value is 0.
/// \return Error object indicating success or failure of the
/// request.
Error SetSharedMemory(
const std::string& region_name, const size_t byte_size,
const size_t offset = 0);
/// Clears the shared memory option set by the last call to
/// InferRequestedOutput::SetSharedMemory(). After a call to this
/// function the requested output will no longer be returned in a
/// shared memory region.
/// \return Error object indicating success or failure of the
/// request.
Error UnsetSharedMemory();
/// \return true if this output is being returned in shared memory.
bool IsSharedMemory() const { return (io_type_ == SHARED_MEMORY); }
/// Get information about the shared memory being used for this
/// output.
/// \param name Returns the name of the shared memory region.
/// \param byte_size Returns the size, in bytes, of the shared
/// memory region.
/// \param offset Returns the offset within the shared memory
/// region.
/// \return Error object indicating success or failure.
Error SharedMemoryInfo(
std::string* name, size_t* byte_size, size_t* offset) const;
/// \return true if this output should be received in binary format.
bool BinaryData() const { return binary_data_; }
/// \return Error object indicating success or failure.
Error SetBinaryData(const bool binary_data);
private:
#ifdef TRITON_INFERENCE_SERVER_CLIENT_CLASS
friend class TRITON_INFERENCE_SERVER_CLIENT_CLASS;
#endif
explicit InferRequestedOutput(
const std::string& name, const size_t class_count = 0);
std::string name_;
size_t class_count_;
// Used only if working with Shared Memory
enum IOType { NONE, RAW, SHARED_MEMORY };
IOType io_type_;
std::string shm_name_;
size_t shm_byte_size_;
size_t shm_offset_;
bool binary_data_{true};
};
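//
// Usage sketch (illustrative; "OUTPUT0" is a placeholder name):
//
//   triton::client::InferRequestedOutput* output;
//   triton::client::Error err =
//       triton::client::InferRequestedOutput::Create(&output, "OUTPUT0");
//
// Passing a non-zero class_count, e.g. Create(&output, "OUTPUT0", 3),
// requests the top-3 classification results instead of the raw tensor.
//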
//==============================================================================
/// An interface for InferResult object to interpret the response to an
/// inference request.
///
class InferResult {
public:
virtual ~InferResult() = default;
/// Get the name of the model which generated this response.
/// \param name Returns the name of the model.
/// \return Error object indicating success or failure.
virtual Error ModelName(std::string* name) const = 0;
/// Get the version of the model which generated this response.
/// \param version Returns the version of the model.
/// \return Error object indicating success or failure.
virtual Error ModelVersion(std::string* version) const = 0;
/// Get the id of the request which generated this response.
/// \param id Returns the id of the request.
/// \return Error object indicating success or failure.
virtual Error Id(std::string* id) const = 0;
/// Get the shape of output result returned in the response.
/// \param output_name The name of the output to get shape.
/// \param shape Returns the shape of result for specified output name.
/// \return Error object indicating success or failure.
virtual Error Shape(
const std::string& output_name, std::vector<int64_t>* shape) const = 0;
/// Get the datatype of output result returned in the response.
/// \param output_name The name of the output to get datatype.
/// \param datatype Returns the datatype of result for specified output name.
/// \return Error object indicating success or failure.
virtual Error Datatype(
const std::string& output_name, std::string* datatype) const = 0;
/// Get access to the buffer holding raw results of specified output
/// returned by the server. Note the buffer is owned by InferResult
/// instance. Users can copy out the data if required to extend the
/// lifetime.
/// \param output_name The name of the output to get result data.
/// \param buf Returns the pointer to the start of the buffer.
/// \param byte_size Returns the size of buffer in bytes.
/// \return Error object indicating success or failure of the
/// request.
virtual Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const = 0;
/// Get final response bool for this response.
/// \return Error object indicating the success or failure.
virtual Error IsFinalResponse(bool* is_final_response) const = 0;
/// Get null response bool for this response.
/// \return Error object indicating the success or failure.
virtual Error IsNullResponse(bool* is_null_response) const = 0;
/// Get the result data as a vector of strings. The vector will
/// receive a copy of result data. An error will be generated if
/// the datatype of output is not 'BYTES'.
/// \param output_name The name of the output to get result data.
/// \param string_result Returns the result data represented as
/// a vector of strings. The strings are stored in the
/// row-major order.
/// \return Error object indicating success or failure of the
/// request.
virtual Error StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const = 0;
/// Returns the complete response as a user-friendly string.
/// \return The string describing the complete response.
virtual std::string DebugString() const = 0;
/// Returns the status of the request.
/// \return Error object indicating the success or failure of the
/// request.
virtual Error RequestStatus() const = 0;
};
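//
// Consumption sketch (illustrative; assumes 'result' was produced by a
// successful Infer() call and has an output named "OUTPUT0"):
//
//   triton::client::Error err = result->RequestStatus();
//   if (err.IsOk()) {
//     const uint8_t* buf;
//     size_t byte_size;
//     err = result->RawData("OUTPUT0", &buf, &byte_size);
//     // 'buf' is owned by 'result'; copy the data out if it must outlive
//     // the InferResult, as documented for RawData() above.
//   }
//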
//==============================================================================
/// Records timestamps for different stages of request handling.
///
class RequestTimers {
public:
/// Timestamp kinds.
enum class Kind {
/// The start of request handling.
REQUEST_START,
/// The end of request handling.
REQUEST_END,
/// The start of sending request bytes to the server (i.e. first
/// byte).
SEND_START,
/// The end of sending request bytes to the server (i.e. last
/// byte).
SEND_END,
/// The start of receiving response bytes from the server
/// (i.e. first byte).
RECV_START,
/// The end of receiving response bytes from the server (i.e. last
/// byte).
RECV_END,
COUNT__
};
/// Construct a timer with zeroed timestamps.
RequestTimers() : timestamps_((size_t)Kind::COUNT__) { Reset(); }
/// Reset all timestamp values to zero. Must be called before
/// re-using the timer.
void Reset()
{
memset(&timestamps_[0], 0, sizeof(uint64_t) * timestamps_.size());
}
/// Get the timestamp, in nanoseconds, for a kind.
/// \param kind The timestamp kind.
/// \return The timestamp in nanoseconds.
uint64_t Timestamp(Kind kind) const { return timestamps_[(size_t)kind]; }
/// Set a timestamp to the current time, in nanoseconds.
/// \param kind The timestamp kind.
/// \return The timestamp in nanoseconds.
uint64_t CaptureTimestamp(Kind kind)
{
uint64_t& ts = timestamps_[(size_t)kind];
ts = std::chrono::duration_cast<std::chrono::nanoseconds>(
std::chrono::high_resolution_clock::now().time_since_epoch())
.count();
return ts;
}
/// Return the duration between the start time point and the end time
/// point, in nanoseconds.
/// \param start The start time point.
/// \param end The end time point.
/// \return Duration in nanoseconds, or
/// std::numeric_limits<uint64_t>::max() to indicate that the duration
/// could not be calculated.
uint64_t Duration(Kind start, Kind end) const
{
const uint64_t stime = timestamps_[(size_t)start];
const uint64_t etime = timestamps_[(size_t)end];
// If the start or end timestamp is 0 then can't calculate the
// duration, so return max to indicate error.
if ((stime == 0) || (etime == 0)) {
return (std::numeric_limits<uint64_t>::max)();
}
return (stime > etime) ? (std::numeric_limits<uint64_t>::max)()
: etime - stime;
}
private:
std::vector<uint64_t> timestamps_;
};
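//
// Usage sketch (illustrative): capture a pair of timestamps and compute the
// elapsed time, which is how the clients below populate InferStat.
//
//   triton::client::RequestTimers timer;
//   timer.CaptureTimestamp(triton::client::RequestTimers::Kind::REQUEST_START);
//   // ... issue the request ...
//   timer.CaptureTimestamp(triton::client::RequestTimers::Kind::REQUEST_END);
//   const uint64_t elapsed_ns = timer.Duration(
//       triton::client::RequestTimers::Kind::REQUEST_START,
//       triton::client::RequestTimers::Kind::REQUEST_END);
//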
//==============================================================================
/// The base class to describe an inflight inference request.
///
class InferRequest {
public:
InferRequest(
InferenceServerClient::OnCompleteFn callback = nullptr,
const bool verbose = false)
: callback_(callback), verbose_(verbose)
{
}
virtual ~InferRequest() = default;
RequestTimers& Timer() { return timer_; }
protected:
InferenceServerClient::OnCompleteFn callback_;
const bool verbose_;
private:
// The timers for infer request.
RequestTimers timer_;
};
}} // namespace triton::client
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Include this first to make sure we are a friend of common classes.
#define TRITON_INFERENCE_SERVER_CLIENT_CLASS InferenceServerGrpcClient
#include "grpc_client.h"
#include <chrono>
#include <cstdint>
#include <fstream>
#include <future>
#include <iostream>
#include <mutex>
#include <sstream>
#include "common.h"
namespace triton { namespace client {
namespace {
//==============================================================================
// Use a map to keep track of GRPC channels. <key, value> : <url,
// <shared_count, Channel*, Stub*>>. If a client is created for a url that
// already has an established channel which hasn't reached the max share
// count, then reuse it.
std::map<
std::string, std::tuple<
size_t, std::shared_ptr<grpc::Channel>,
std::shared_ptr<inference::GRPCInferenceService::Stub>>>
grpc_channel_stub_map_;
std::mutex grpc_channel_stub_map_mtx_;
std::string
GetEnvironmentVariableOrDefault(
const std::string& variable_name, const std::string& default_value)
{
const char* value = getenv(variable_name.c_str());
return value ? value : default_value;
}
void
ReadFile(const std::string& filename, std::string& data)
{
data.clear();
if (!filename.empty()) {
std::ifstream file(filename.c_str(), std::ios::in);
if (file.is_open()) {
std::stringstream ss;
ss << file.rdbuf();
file.close();
data = ss.str();
}
}
}
std::shared_ptr<inference::GRPCInferenceService::Stub>
GetStub(
const std::string& url, bool use_ssl, const SslOptions& ssl_options,
const grpc::ChannelArguments& channel_args, const bool use_cached_channel,
bool verbose)
{
std::lock_guard<std::mutex> lock(grpc_channel_stub_map_mtx_);
// Limit the amount of sharing for each channel connected to the url;
// distributing clients across different channels relieves the pressure
// of reaching the max connection concurrency.
// https://grpc.io/docs/guides/performance/ (4th point)
static const size_t max_share_count =
std::stoul(GetEnvironmentVariableOrDefault(
"TRITON_CLIENT_GRPC_CHANNEL_MAX_SHARE_COUNT", "6"));
const auto& channel_itr = grpc_channel_stub_map_.find(url);
// Reuse the cached channel if the channel is found in the map and the
// use_cached_channel flag is true
if ((channel_itr != grpc_channel_stub_map_.end()) && use_cached_channel) {
// check if NewStub should be created
const auto& shared_count = std::get<0>(channel_itr->second);
if (shared_count % max_share_count != 0) {
std::get<0>(channel_itr->second)++;
return std::get<2>(channel_itr->second);
}
}
if (verbose) {
std::cout << "Creating new channel with url:" << url << std::endl;
}
// Start with a copy of channel_args param, then modify our copy as needed.
grpc::ChannelArguments arguments(channel_args);
static std::atomic<int> channel_count{0};
// Explicitly avoid channel re-use
// "channels must have different channel args to prevent re-use
// so define a use-specific channel arg such as channel number"
// https://grpc.io/docs/guides/performance/
// NOTE: The argument name "triton_client_channel_idx" is arbitrary.
arguments.SetInt("triton_client_channel_idx", channel_count.fetch_add(1));
std::shared_ptr<grpc::ChannelCredentials> credentials;
if (use_ssl) {
std::string root;
std::string key;
std::string cert;
ReadFile(ssl_options.root_certificates, root);
ReadFile(ssl_options.private_key, key);
ReadFile(ssl_options.certificate_chain, cert);
grpc::SslCredentialsOptions opts = {root, key, cert};
credentials = grpc::SslCredentials(opts);
} else {
credentials = grpc::InsecureChannelCredentials();
}
std::shared_ptr<grpc::Channel> channel =
grpc::CreateCustomChannel(url, credentials, arguments);
std::shared_ptr<inference::GRPCInferenceService::Stub> stub =
inference::GRPCInferenceService::NewStub(channel);
// Replace the entry if a channel / stub already exist in the map
if (channel_itr != grpc_channel_stub_map_.end()) {
channel_itr->second = std::make_tuple(1, channel, stub);
} else {
grpc_channel_stub_map_.insert(
std::make_pair(url, std::make_tuple(1, channel, stub)));
}
return stub;
}
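// Configuration note (illustrative): the per-channel share count above is
// read from an environment variable, so a deployment that wants a dedicated
// gRPC channel per client can set, for example,
//
//   TRITON_CLIENT_GRPC_CHANNEL_MAX_SHARE_COUNT=1
//
// before creating clients; with the default of 6, up to six clients
// connecting to the same url share one channel.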
} // namespace
//==============================================================================
// A GrpcInferRequest represents an in-flight inference request on gRPC.
//
class GrpcInferRequest : public InferRequest {
public:
GrpcInferRequest(InferenceServerClient::OnCompleteFn callback = nullptr)
: InferRequest(callback), grpc_status_(),
grpc_response_(std::make_shared<inference::ModelInferResponse>())
{
}
friend InferenceServerGrpcClient;
private:
// Variables for GRPC call
grpc::ClientContext grpc_context_;
grpc::Status grpc_status_;
std::shared_ptr<inference::ModelInferResponse> grpc_response_;
};
//==============================================================================
class InferResultGrpc : public InferResult {
public:
static Error Create(
InferResult** infer_result,
std::shared_ptr<inference::ModelInferResponse> response,
Error& request_status);
static Error Create(
InferResult** infer_result,
std::shared_ptr<inference::ModelStreamInferResponse> response);
Error RequestStatus() const override;
Error ModelName(std::string* name) const override;
Error ModelVersion(std::string* version) const override;
Error Id(std::string* id) const override;
Error Shape(const std::string& output_name, std::vector<int64_t>* shape)
const override;
Error Datatype(
const std::string& output_name, std::string* datatype) const override;
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
Error IsFinalResponse(bool* is_final_response) const override;
Error IsNullResponse(bool* is_null_response) const override;
Error StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const override;
std::string DebugString() const override { return response_->DebugString(); }
private:
InferResultGrpc(
std::shared_ptr<inference::ModelInferResponse> response,
Error& request_status);
InferResultGrpc(
std::shared_ptr<inference::ModelStreamInferResponse> response);
std::map<std::string, const inference::ModelInferResponse::InferOutputTensor*>
output_name_to_tensor_map_;
std::map<std::string, std::pair<const uint8_t*, const uint32_t>>
output_name_to_buffer_map_;
std::shared_ptr<inference::ModelInferResponse> response_;
std::shared_ptr<inference::ModelStreamInferResponse> stream_response_;
Error request_status_;
bool is_final_response_{true};
bool is_null_response_{false};
};
Error
InferResultGrpc::Create(
InferResult** infer_result,
std::shared_ptr<inference::ModelInferResponse> response,
Error& request_status)
{
*infer_result = reinterpret_cast<InferResult*>(
new InferResultGrpc(response, request_status));
return Error::Success;
}
Error
InferResultGrpc::Create(
InferResult** infer_result,
std::shared_ptr<inference::ModelStreamInferResponse> response)
{
*infer_result = reinterpret_cast<InferResult*>(new InferResultGrpc(response));
return Error::Success;
}
Error
InferResultGrpc::RequestStatus() const
{
return request_status_;
}
Error
InferResultGrpc::ModelName(std::string* name) const
{
*name = response_->model_name();
return Error::Success;
}
Error
InferResultGrpc::ModelVersion(std::string* version) const
{
*version = response_->model_version();
return Error::Success;
}
Error
InferResultGrpc::Id(std::string* id) const
{
*id = response_->id();
return Error::Success;
}
Error
InferResultGrpc::Shape(
const std::string& output_name, std::vector<int64_t>* shape) const
{
shape->clear();
auto it = output_name_to_tensor_map_.find(output_name);
if (it != output_name_to_tensor_map_.end()) {
for (const auto dim : it->second->shape()) {
shape->push_back(dim);
}
} else {
return Error(
"The response does not contain shape for output name '" + output_name +
"'");
}
return Error::Success;
}
Error
InferResultGrpc::Datatype(
const std::string& output_name, std::string* datatype) const
{
auto it = output_name_to_tensor_map_.find(output_name);
if (it != output_name_to_tensor_map_.end()) {
*datatype = it->second->datatype();
} else {
return Error(
"The response does not contain datatype for output name '" +
output_name + "'");
}
return Error::Success;
}
Error
InferResultGrpc::RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const
{
auto it = output_name_to_buffer_map_.find(output_name);
if (it != output_name_to_buffer_map_.end()) {
*buf = it->second.first;
*byte_size = it->second.second;
} else {
return Error(
"The response does not contain results for output name '" +
output_name + "'");
}
return Error::Success;
}
Error
InferResultGrpc::IsFinalResponse(bool* is_final_response) const
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = is_final_response_;
return Error::Success;
}
Error
InferResultGrpc::IsNullResponse(bool* is_null_response) const
{
if (is_null_response == nullptr) {
return Error("is_null_response cannot be nullptr");
}
*is_null_response = is_null_response_;
return Error::Success;
}
Error
InferResultGrpc::StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const
{
std::string datatype;
Error err = Datatype(output_name, &datatype);
if (!err.IsOk()) {
return err;
}
if (datatype.compare("BYTES") != 0) {
return Error(
"This function supports tensors with datatype 'BYTES', requested "
"output tensor '" +
output_name + "' with datatype '" + datatype + "'");
}
const uint8_t* buf;
size_t byte_size;
err = RawData(output_name, &buf, &byte_size);
string_result->clear();
if (byte_size != 0) {
size_t buf_offset = 0;
while (byte_size > buf_offset) {
const uint32_t element_size =
*(reinterpret_cast<const uint32_t*>(buf + buf_offset));
string_result->emplace_back(
reinterpret_cast<const char*>(
buf + buf_offset + sizeof(element_size)),
element_size);
buf_offset += (sizeof(element_size) + element_size);
}
} else {
auto it = output_name_to_tensor_map_.find(output_name);
for (const auto& element : it->second->contents().bytes_contents()) {
string_result->push_back(element);
}
}
return Error::Success;
}
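// Layout note (illustrative): the raw BYTES buffer parsed above is a
// sequence of length-prefixed elements, each a 4-byte length (read as a
// uint32_t in host byte order by the loop) followed by that many bytes of
// string data:
//
//   [len0][data0...][len1][data1...]...
//
// so, for example, the two strings "ab" and "xyz" occupy 4 + 2 bytes and
// 4 + 3 bytes of the buffer respectively.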
InferResultGrpc::InferResultGrpc(
std::shared_ptr<inference::ModelInferResponse> response,
Error& request_status)
: response_(response), request_status_(request_status)
{
uint32_t index = 0;
for (const auto& output : response_->outputs()) {
output_name_to_tensor_map_[output.name()] = &output;
const uint8_t* buf =
(uint8_t*)&(response_->raw_output_contents()[index][0]);
const uint32_t byte_size = response_->raw_output_contents()[index].size();
output_name_to_buffer_map_.insert(
std::make_pair(output.name(), std::make_pair(buf, byte_size)));
index++;
}
const auto& is_final_response_itr{
response_->parameters().find("triton_final_response")};
if (is_final_response_itr != response_->parameters().end()) {
is_final_response_ = is_final_response_itr->second.bool_param();
}
is_null_response_ = response_->outputs().empty() && is_final_response_;
}
InferResultGrpc::InferResultGrpc(
std::shared_ptr<inference::ModelStreamInferResponse> stream_response)
: stream_response_(stream_response)
{
request_status_ = Error(stream_response_->error_message());
response_.reset(
stream_response->mutable_infer_response(),
[](inference::ModelInferResponse*) {});
uint32_t index = 0;
for (const auto& output : response_->outputs()) {
output_name_to_tensor_map_[output.name()] = &output;
const uint8_t* buf =
(uint8_t*)&(response_->raw_output_contents()[index][0]);
const uint32_t byte_size = response_->raw_output_contents()[index].size();
output_name_to_buffer_map_.insert(
std::make_pair(output.name(), std::make_pair(buf, byte_size)));
index++;
}
const auto& is_final_response_itr{
response_->parameters().find("triton_final_response")};
if (is_final_response_itr != response_->parameters().end()) {
is_final_response_ = is_final_response_itr->second.bool_param();
}
is_null_response_ = response_->outputs().empty() && is_final_response_;
}
//==============================================================================
// Advanced users can generically pass any channel arguments
// through the channel_args parameter, including KeepAlive options.
// Channel arguments provided by the user are at the user's own risk and
// are assumed to be correct/complete.
Error
InferenceServerGrpcClient::Create(
std::unique_ptr<InferenceServerGrpcClient>* client,
const std::string& server_url, const grpc::ChannelArguments& channel_args,
bool verbose, bool use_ssl, const SslOptions& ssl_options,
const bool use_cached_channel)
{
client->reset(new InferenceServerGrpcClient(
server_url, verbose, use_ssl, ssl_options, channel_args,
use_cached_channel));
return Error::Success;
}
// Most users should use this method of creating a client unless
// they have an advanced use case that is not supported.
Error
InferenceServerGrpcClient::Create(
std::unique_ptr<InferenceServerGrpcClient>* client,
const std::string& server_url, bool verbose, bool use_ssl,
const SslOptions& ssl_options, const KeepAliveOptions& keepalive_options,
const bool use_cached_channel)
{
// Construct channel arguments specific to Triton
grpc::ChannelArguments channel_args;
channel_args.SetMaxSendMessageSize(MAX_GRPC_MESSAGE_SIZE);
channel_args.SetMaxReceiveMessageSize(MAX_GRPC_MESSAGE_SIZE);
// GRPC KeepAlive: https://github.com/grpc/grpc/blob/master/doc/keepalive.md
channel_args.SetInt(
GRPC_ARG_KEEPALIVE_TIME_MS, keepalive_options.keepalive_time_ms);
channel_args.SetInt(
GRPC_ARG_KEEPALIVE_TIMEOUT_MS, keepalive_options.keepalive_timeout_ms);
channel_args.SetInt(
GRPC_ARG_KEEPALIVE_PERMIT_WITHOUT_CALLS,
keepalive_options.keepalive_permit_without_calls);
channel_args.SetInt(
GRPC_ARG_HTTP2_MAX_PINGS_WITHOUT_DATA,
keepalive_options.http2_max_pings_without_data);
client->reset(new InferenceServerGrpcClient(
server_url, verbose, use_ssl, ssl_options, channel_args,
use_cached_channel));
return Error::Success;
}
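// Construction sketch (illustrative; "localhost:8001" is a placeholder gRPC
// endpoint, and the remaining parameters are assumed to have the default
// arguments declared in grpc_client.h):
//
//   std::unique_ptr<triton::client::InferenceServerGrpcClient> client;
//   triton::client::Error err =
//       triton::client::InferenceServerGrpcClient::Create(
//           &client, "localhost:8001", /*verbose=*/false);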
Error
InferenceServerGrpcClient::IsServerLive(bool* live, const Headers& headers)
{
Error err;
inference::ServerLiveRequest request;
inference::ServerLiveResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
grpc::Status grpc_status = stub_->ServerLive(&context, request, &response);
if (grpc_status.ok()) {
*live = response.live();
if (verbose_) {
std::cout << "Server Live : " << *live << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::IsServerReady(bool* ready, const Headers& headers)
{
Error err;
inference::ServerReadyRequest request;
inference::ServerReadyResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
grpc::Status grpc_status = stub_->ServerReady(&context, request, &response);
if (grpc_status.ok()) {
*ready = response.ready();
if (verbose_) {
std::cout << "Server Ready : " << *ready << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::IsModelReady(
bool* ready, const std::string& model_name,
const std::string& model_version, const Headers& headers)
{
Error err;
inference::ModelReadyRequest request;
inference::ModelReadyResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(model_name);
request.set_version(model_version);
grpc::Status grpc_status = stub_->ModelReady(&context, request, &response);
if (grpc_status.ok()) {
*ready = response.ready();
if (verbose_) {
std::cout << "Model Ready : name: " << model_name;
if (!model_version.empty()) {
std::cout << "(version: " << model_version << ") ";
}
std::cout << ": " << *ready << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::ServerMetadata(
inference::ServerMetadataResponse* server_metadata, const Headers& headers)
{
server_metadata->Clear();
Error err;
inference::ServerMetadataRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
grpc::Status grpc_status =
stub_->ServerMetadata(&context, request, server_metadata);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << server_metadata->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::ModelMetadata(
inference::ModelMetadataResponse* model_metadata,
const std::string& model_name, const std::string& model_version,
const Headers& headers)
{
model_metadata->Clear();
Error err;
inference::ModelMetadataRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(model_name);
request.set_version(model_version);
grpc::Status grpc_status =
stub_->ModelMetadata(&context, request, model_metadata);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << model_metadata->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::ModelConfig(
inference::ModelConfigResponse* model_config, const std::string& model_name,
const std::string& model_version, const Headers& headers)
{
model_config->Clear();
Error err;
inference::ModelConfigRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(model_name);
request.set_version(model_version);
grpc::Status grpc_status =
stub_->ModelConfig(&context, request, model_config);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << model_config->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::ModelRepositoryIndex(
inference::RepositoryIndexResponse* repository_index,
const Headers& headers)
{
repository_index->Clear();
Error err;
inference::RepositoryIndexRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
grpc::Status grpc_status =
stub_->RepositoryIndex(&context, request, repository_index);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << repository_index->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::LoadModel(
const std::string& model_name, const Headers& headers,
const std::string& config,
const std::map<std::string, std::vector<char>>& files)
{
Error err;
inference::RepositoryModelLoadRequest request;
inference::RepositoryModelLoadResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_model_name(model_name);
if (!config.empty()) {
(*request.mutable_parameters())["config"].set_string_param(config);
}
for (const auto& file : files) {
(*request.mutable_parameters())[file.first].set_bytes_param(
file.second.data(), file.second.size());
}
grpc::Status grpc_status =
stub_->RepositoryModelLoad(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
std::cout << "Loaded model '" << model_name << "'" << std::endl;
}
}
return err;
}
Error
InferenceServerGrpcClient::UnloadModel(
const std::string& model_name, const Headers& headers)
{
Error err;
inference::RepositoryModelUnloadRequest request;
inference::RepositoryModelUnloadResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_model_name(model_name);
grpc::Status grpc_status =
stub_->RepositoryModelUnload(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
std::cout << "Unloaded model '" << model_name << "'" << std::endl;
}
}
return err;
}
Error
InferenceServerGrpcClient::ModelInferenceStatistics(
inference::ModelStatisticsResponse* infer_stat,
const std::string& model_name, const std::string& model_version,
const Headers& headers)
{
infer_stat->Clear();
Error err;
inference::ModelStatisticsRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(model_name);
request.set_version(model_version);
grpc::Status grpc_status =
stub_->ModelStatistics(&context, request, infer_stat);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << infer_stat->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::UpdateTraceSettings(
inference::TraceSettingResponse* response, const std::string& model_name,
const std::map<std::string, std::vector<std::string>>& settings,
const Headers& headers)
{
inference::TraceSettingRequest request;
grpc::ClientContext context;
Error err;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
if (!model_name.empty()) {
request.set_model_name(model_name);
}
if (!settings.empty()) {
for (const auto& pr : settings) {
if (pr.second.empty()) {
(*request.mutable_settings())[pr.first].clear_value();
} else {
for (const auto& v : pr.second) {
(*request.mutable_settings())[pr.first].add_value(v);
}
}
}
}
grpc::Status grpc_status = stub_->TraceSetting(&context, request, response);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << "Update trace settings " << response->DebugString()
<< std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::GetTraceSettings(
inference::TraceSettingResponse* settings, const std::string& model_name,
const Headers& headers)
{
settings->Clear();
Error err;
inference::TraceSettingRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
if (!model_name.empty()) {
request.set_model_name(model_name);
}
grpc::Status grpc_status = stub_->TraceSetting(&context, request, settings);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << settings->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::SystemSharedMemoryStatus(
inference::SystemSharedMemoryStatusResponse* status,
const std::string& region_name, const Headers& headers)
{
status->Clear();
Error err;
inference::SystemSharedMemoryStatusRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(region_name);
grpc::Status grpc_status =
stub_->SystemSharedMemoryStatus(&context, request, status);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << status->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size,
const size_t offset, const Headers& headers)
{
Error err;
inference::SystemSharedMemoryRegisterRequest request;
inference::SystemSharedMemoryRegisterResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(name);
request.set_key(key);
request.set_offset(offset);
request.set_byte_size(byte_size);
grpc::Status grpc_status =
stub_->SystemSharedMemoryRegister(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
std::cout << "Registered system shared memory with name '" << name << "'"
<< std::endl;
}
}
return err;
}
Error
InferenceServerGrpcClient::UnregisterSystemSharedMemory(
const std::string& name, const Headers& headers)
{
Error err;
inference::SystemSharedMemoryUnregisterRequest request;
inference::SystemSharedMemoryUnregisterResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(name);
grpc::Status grpc_status =
stub_->SystemSharedMemoryUnregister(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
if (name.size() != 0) {
std::cout << "Unregistered system shared memory with name '" << name
<< "'" << std::endl;
} else {
std::cout << "Unregistered all system shared memory regions"
<< std::endl;
}
}
}
return err;
}
Error
InferenceServerGrpcClient::CudaSharedMemoryStatus(
inference::CudaSharedMemoryStatusResponse* status,
const std::string& region_name, const Headers& headers)
{
status->Clear();
Error err;
inference::CudaSharedMemoryStatusRequest request;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(region_name);
grpc::Status grpc_status =
stub_->CudaSharedMemoryStatus(&context, request, status);
if (grpc_status.ok()) {
if (verbose_) {
std::cout << status->DebugString() << std::endl;
}
} else {
err = Error(grpc_status.error_message());
}
return err;
}
Error
InferenceServerGrpcClient::RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& cuda_shm_handle,
const size_t device_id, const size_t byte_size, const Headers& headers)
{
Error err;
inference::CudaSharedMemoryRegisterRequest request;
inference::CudaSharedMemoryRegisterResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(name);
request.set_raw_handle((char*)&cuda_shm_handle, sizeof(cudaIpcMemHandle_t));
request.set_device_id(device_id);
request.set_byte_size(byte_size);
grpc::Status grpc_status =
stub_->CudaSharedMemoryRegister(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
std::cout << "Registered cuda shared memory with name '" << name << "'"
<< std::endl;
}
}
return err;
}
Error
InferenceServerGrpcClient::UnregisterCudaSharedMemory(
const std::string& name, const Headers& headers)
{
Error err;
inference::CudaSharedMemoryUnregisterRequest request;
inference::CudaSharedMemoryUnregisterResponse response;
grpc::ClientContext context;
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
request.set_name(name);
grpc::Status grpc_status =
stub_->CudaSharedMemoryUnregister(&context, request, &response);
if (!grpc_status.ok()) {
err = Error(grpc_status.error_message());
} else {
if (verbose_) {
if (name.size() != 0) {
std::cout << "Unregistered system shared memory with name '" << name
<< "'" << std::endl;
} else {
std::cout << "Unregistered all system shared memory regions"
<< std::endl;
}
}
}
return err;
}
Error
InferenceServerGrpcClient::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, grpc_compression_algorithm compression_algorithm)
{
Error err;
grpc::ClientContext context;
std::shared_ptr<GrpcInferRequest> sync_request(new GrpcInferRequest());
sync_request->Timer().Reset();
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_START);
// Use send timer to measure time for marshalling infer request
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
for (const auto& it : headers) {
context.AddMetadata(it.first, it.second);
}
if (options.client_timeout_ != 0) {
auto deadline = std::chrono::system_clock::now() +
std::chrono::microseconds(options.client_timeout_);
context.set_deadline(deadline);
}
context.set_compression_algorithm(compression_algorithm);
err = PreRunProcessing(options, inputs, outputs);
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
if (!err.IsOk()) {
return err;
}
sync_request->grpc_response_->Clear();
sync_request->grpc_status_ = stub_->ModelInfer(
&context, infer_request_, sync_request->grpc_response_.get());
if (!sync_request->grpc_status_.ok()) {
err = Error(sync_request->grpc_status_.error_message());
}
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_START);
InferResultGrpc::Create(result, sync_request->grpc_response_, err);
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_END);
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_END);
err = UpdateInferStat(sync_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
}
if (sync_request->grpc_status_.ok()) {
if (verbose_) {
std::cout << sync_request->grpc_response_->DebugString() << std::endl;
}
}
return (*result)->RequestStatus();
}
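// End-to-end sketch (illustrative; ties together the pieces above with
// placeholder names, assuming the header/compression parameters have the
// defaults declared in grpc_client.h):
//
//   std::vector<triton::client::InferInput*> inputs{input};
//   std::vector<const triton::client::InferRequestedOutput*> outputs{output};
//   triton::client::InferResult* result;
//   triton::client::Error err =
//       client->Infer(&result, options, inputs, outputs);
//
// The InferResult is heap-allocated by InferResultGrpc::Create() above, so
// the caller is responsible for releasing it when done.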
Error
InferenceServerGrpcClient::AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, grpc_compression_algorithm compression_algorithm)
{
if (callback == nullptr) {
return Error(
"Callback function must be provided along with AsyncInfer() call.");
}
if (!worker_.joinable()) {
worker_ = std::thread(&InferenceServerGrpcClient::AsyncTransfer, this);
}
GrpcInferRequest* async_request;
async_request = new GrpcInferRequest(std::move(callback));
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_START);
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
for (const auto& it : headers) {
async_request->grpc_context_.AddMetadata(it.first, it.second);
}
if (options.client_timeout_ != 0) {
auto deadline = std::chrono::system_clock::now() +
std::chrono::microseconds(options.client_timeout_);
async_request->grpc_context_.set_deadline(deadline);
}
async_request->grpc_context_.set_compression_algorithm(compression_algorithm);
Error err = PreRunProcessing(options, inputs, outputs);
if (!err.IsOk()) {
delete async_request;
return err;
}
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
std::unique_ptr<
grpc::ClientAsyncResponseReader<inference::ModelInferResponse>>
rpc(stub_->PrepareAsyncModelInfer(
&async_request->grpc_context_, infer_request_,
&async_request_completion_queue_));
rpc->StartCall();
rpc->Finish(
async_request->grpc_response_.get(), &async_request->grpc_status_,
(void*)async_request);
if (verbose_) {
std::cout << "Sent request";
if (options.request_id_.size() != 0) {
std::cout << " '" << options.request_id_ << "'";
}
std::cout << std::endl;
}
return Error::Success;
}
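// Asynchronous usage sketch (illustrative): the callback runs on the
// client's worker thread and receives a heap-allocated InferResult that it
// is expected to release (an assumption based on the synchronous path
// above; AsyncTransfer() is defined elsewhere).
//
//   client->AsyncInfer(
//       [](triton::client::InferResult* result) {
//         // check result->RequestStatus(), read outputs, then release it
//         delete result;
//       },
//       options, inputs, outputs);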
Error
InferenceServerGrpcClient::InferMulti(
std::vector<InferResult*>* results,
const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs,
const Headers& headers, grpc_compression_algorithm compression_algorithm)
{
Error err;
// Sanity check
if ((inputs.size() != options.size()) && (options.size() != 1)) {
return Error(
"'options' must either contain 1 element or match size of 'inputs'");
}
if ((inputs.size() != outputs.size()) &&
((outputs.size() != 1) && (outputs.size() != 0))) {
return Error(
"'outputs' must either contain 0/1 element or match size of 'inputs'");
}
int64_t max_option_idx = options.size() - 1;
// value of '-1' means no output is specified
int64_t max_output_idx = outputs.size() - 1;
static std::vector<const InferRequestedOutput*> empty_outputs{};
for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) {
const auto& request_options = options[std::min(max_option_idx, i)];
const auto& request_output = (max_output_idx == -1)
? empty_outputs
: outputs[std::min(max_output_idx, i)];
results->emplace_back();
err = Infer(
&results->back(), request_options, inputs[i], request_output, headers,
compression_algorithm);
if (!err.IsOk()) {
return err;
}
}
return Error::Success;
}
Error
InferenceServerGrpcClient::AsyncInferMulti(
OnMultiCompleteFn callback, const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs,
const Headers& headers, grpc_compression_algorithm compression_algorithm)
{
// Sanity check
if ((inputs.size() != options.size()) && (options.size() != 1)) {
return Error(
"'options' must either contain 1 element or match size of 'inputs'");
}
if ((inputs.size() != outputs.size()) &&
((outputs.size() != 1) && (outputs.size() != 0))) {
return Error(
"'outputs' must either contain 0/1 element or match size of 'inputs'");
}
if (callback == nullptr) {
return Error(
"Callback function must be provided along with AsyncInferMulti() "
"call.");
}
if (!worker_.joinable()) {
worker_ = std::thread(&InferenceServerGrpcClient::AsyncTransfer, this);
}
int64_t max_option_idx = options.size() - 1;
// value of '-1' means no output is specified
int64_t max_output_idx = outputs.size() - 1;
static std::vector<const InferRequestedOutput*> empty_outputs{};
std::shared_ptr<std::atomic<size_t>> response_counter(
new std::atomic<size_t>(inputs.size()));
std::shared_ptr<std::vector<InferResult*>> responses(
new std::vector<InferResult*>(inputs.size()));
for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) {
const auto& request_options = options[std::min(max_option_idx, i)];
const auto& request_output = (max_output_idx == -1)
? empty_outputs
: outputs[std::min(max_output_idx, i)];
OnCompleteFn cb = [response_counter, responses, i,
callback](InferResult* result) {
(*responses)[i] = result;
// last response
if (response_counter->fetch_sub(1) == 1) {
std::vector<InferResult*> results;
results.swap(*responses);
callback(results);
}
};
auto err = AsyncInfer(
cb, request_options, inputs[i], request_output, headers,
compression_algorithm);
if (!err.IsOk()) {
      // Create a response carrying the error since other requests may already
      // have been sent and their responses may not be accessible outside the
      // callback.
InferResult* err_res;
std::shared_ptr<inference::ModelInferResponse> empty_response(
new inference::ModelInferResponse());
InferResultGrpc::Create(&err_res, empty_response, err);
cb(err_res);
continue;
}
}
return Error::Success;
}
Error
InferenceServerGrpcClient::StartStream(
OnCompleteFn callback, bool enable_stats, uint32_t stream_timeout,
const Headers& headers, grpc_compression_algorithm compression_algorithm)
{
if (stream_worker_.joinable()) {
return Error(
"cannot start another stream with one already running. "
"'InferenceServerClient' supports only a single active "
"stream at a given time.");
}
if (callback == nullptr) {
return Error(
"Callback function must be provided along with StartStream() call.");
}
stream_callback_ = callback;
enable_stream_stats_ = enable_stats;
for (const auto& it : headers) {
grpc_context_.AddMetadata(it.first, it.second);
}
if (stream_timeout != 0) {
auto deadline = std::chrono::system_clock::now() +
std::chrono::microseconds(stream_timeout);
grpc_context_.set_deadline(deadline);
}
grpc_context_.set_compression_algorithm(compression_algorithm);
grpc_stream_ = stub_->ModelStreamInfer(&grpc_context_);
stream_worker_ =
std::thread(&InferenceServerGrpcClient::AsyncStreamTransfer, this);
if (verbose_) {
std::cout << "Started stream..." << std::endl;
}
return Error::Success;
}
Error
InferenceServerGrpcClient::StopStream()
{
if (stream_worker_.joinable()) {
grpc_stream_->WritesDone();
// The reader thread will drain the stream properly
stream_worker_.join();
if (verbose_) {
std::cout << "Stopped stream..." << std::endl;
}
}
return Error::Success;
}
Error
InferenceServerGrpcClient::AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
std::unique_ptr<RequestTimers> timer;
if (enable_stream_stats_) {
timer.reset(new RequestTimers());
timer->CaptureTimestamp(RequestTimers::Kind::REQUEST_START);
timer->CaptureTimestamp(RequestTimers::Kind::SEND_START);
}
Error err = PreRunProcessing(options, inputs, outputs);
if (!err.IsOk()) {
return err;
}
if (enable_stream_stats_) {
timer->CaptureTimestamp(RequestTimers::Kind::SEND_END);
}
if (enable_stream_stats_) {
std::lock_guard<std::mutex> lock(stream_mutex_);
ongoing_stream_request_timers_.push(std::move(timer));
}
bool ok = grpc_stream_->Write(infer_request_);
if (ok) {
if (verbose_) {
std::cout << "Sent request";
if (options.request_id_.size() != 0) {
std::cout << " '" << options.request_id_ << "'";
}
std::cout << " to the stream" << std::endl;
}
return Error::Success;
} else {
return Error("Stream has been closed.");
}
}
Error
InferenceServerGrpcClient::PreRunProcessing(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
// Populate the request protobuf
infer_request_.set_model_name(options.model_name_);
infer_request_.set_model_version(options.model_version_);
infer_request_.set_id(options.request_id_);
infer_request_.mutable_parameters()->clear();
(*infer_request_.mutable_parameters())["triton_enable_empty_final_response"]
.set_bool_param(options.triton_enable_empty_final_response_);
if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "")) {
if (options.sequence_id_ != 0) {
(*infer_request_.mutable_parameters())["sequence_id"].set_int64_param(
options.sequence_id_);
} else {
(*infer_request_.mutable_parameters())["sequence_id"].set_string_param(
options.sequence_id_str_);
}
(*infer_request_.mutable_parameters())["sequence_start"].set_bool_param(
options.sequence_start_);
(*infer_request_.mutable_parameters())["sequence_end"].set_bool_param(
options.sequence_end_);
}
if (options.priority_ != 0) {
(*infer_request_.mutable_parameters())["priority"].set_uint64_param(
options.priority_);
}
if (options.server_timeout_ != 0) {
(*infer_request_.mutable_parameters())["timeout"].set_int64_param(
options.server_timeout_);
}
int index = 0;
infer_request_.mutable_raw_input_contents()->Clear();
for (const auto input : inputs) {
    // Add new InferInputTensor submessages only if required; otherwise
    // reuse the submessages already available.
auto grpc_input = (infer_request_.inputs().size() <= index)
? infer_request_.add_inputs()
: infer_request_.mutable_inputs()->Mutable(index);
if (input->IsSharedMemory()) {
// The input contents must be cleared when using shared memory.
grpc_input->Clear();
}
grpc_input->set_name(input->Name());
grpc_input->mutable_shape()->Clear();
for (const auto dim : input->Shape()) {
grpc_input->mutable_shape()->Add(dim);
}
grpc_input->set_datatype(input->Datatype());
input->PrepareForRequest();
grpc_input->mutable_parameters()->clear();
if (input->IsSharedMemory()) {
std::string region_name;
size_t offset;
size_t byte_size;
input->SharedMemoryInfo(&region_name, &byte_size, &offset);
(*grpc_input->mutable_parameters())["shared_memory_region"]
.set_string_param(region_name);
(*grpc_input->mutable_parameters())["shared_memory_byte_size"]
.set_int64_param(byte_size);
if (offset != 0) {
(*grpc_input->mutable_parameters())["shared_memory_offset"]
.set_int64_param(offset);
}
} else {
bool end_of_input = false;
std::string* raw_contents = infer_request_.add_raw_input_contents();
size_t content_size;
input->ByteSize(&content_size);
raw_contents->reserve(content_size);
raw_contents->clear();
while (!end_of_input) {
const uint8_t* buf;
size_t buf_size;
input->GetNext(&buf, &buf_size, &end_of_input);
if (buf != nullptr) {
raw_contents->append(reinterpret_cast<const char*>(buf), buf_size);
}
}
}
index++;
}
  // Remove extra InferInputTensor submessages that are not required for
// this request.
while (index < infer_request_.inputs().size()) {
infer_request_.mutable_inputs()->RemoveLast();
}
index = 0;
for (const auto routput : outputs) {
// Add new InferRequestedOutputTensor submessage only if required, otherwise
// reuse the submessages already available.
auto grpc_output = (infer_request_.outputs().size() <= index)
? infer_request_.add_outputs()
: infer_request_.mutable_outputs()->Mutable(index);
grpc_output->Clear();
grpc_output->set_name(routput->Name());
size_t class_count = routput->ClassificationCount();
if (class_count != 0) {
(*grpc_output->mutable_parameters())["classification"].set_int64_param(
class_count);
}
if (routput->IsSharedMemory()) {
std::string region_name;
size_t offset;
size_t byte_size;
routput->SharedMemoryInfo(&region_name, &byte_size, &offset);
(*grpc_output->mutable_parameters())["shared_memory_region"]
.set_string_param(region_name);
(*grpc_output->mutable_parameters())["shared_memory_byte_size"]
.set_int64_param(byte_size);
if (offset != 0) {
(*grpc_output->mutable_parameters())["shared_memory_offset"]
.set_int64_param(offset);
}
}
index++;
}
  // Remove extra InferRequestedOutputTensor submessages that are not required
// for this request.
while (index < infer_request_.outputs().size()) {
infer_request_.mutable_outputs()->RemoveLast();
}
if (infer_request_.ByteSizeLong() > INT_MAX) {
size_t request_size = infer_request_.ByteSizeLong();
infer_request_.Clear();
return Error(
"Request has byte size " + std::to_string(request_size) +
" which exceed gRPC's byte size limit " + std::to_string(INT_MAX) +
".");
}
return Error::Success;
}
void
InferenceServerGrpcClient::AsyncTransfer()
{
while (!exiting_) {
// GRPC async APIs are thread-safe https://github.com/grpc/grpc/issues/4486
GrpcInferRequest* raw_async_request;
bool ok = true;
bool status =
async_request_completion_queue_.Next((void**)(&raw_async_request), &ok);
std::shared_ptr<GrpcInferRequest> async_request;
if (!ok) {
fprintf(stderr, "Unexpected not ok on client side.\n");
}
if (!status) {
if (!exiting_) {
fprintf(stderr, "Completion queue is closed.\n");
}
} else if (raw_async_request == nullptr) {
fprintf(stderr, "Unexpected null tag received at client.\n");
} else {
async_request.reset(raw_async_request);
InferResult* async_result;
Error err;
if (!async_request->grpc_status_.ok()) {
err = Error(async_request->grpc_status_.error_message());
}
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_START);
InferResultGrpc::Create(
&async_result, async_request->grpc_response_, err);
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_END);
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_END);
err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
}
if (async_request->grpc_status_.ok()) {
if (verbose_) {
std::cout << async_request->grpc_response_->DebugString()
<< std::endl;
}
}
async_request->callback_(async_result);
}
}
}
void
InferenceServerGrpcClient::AsyncStreamTransfer()
{
std::shared_ptr<inference::ModelStreamInferResponse> response =
std::make_shared<inference::ModelStreamInferResponse>();
// End loop if Read() returns false
// (stream ended and all responses are drained)
while (grpc_stream_->Read(response.get())) {
if (exiting_) {
continue;
}
std::unique_ptr<RequestTimers> timer;
if (enable_stream_stats_) {
std::lock_guard<std::mutex> lock(stream_mutex_);
if (!ongoing_stream_request_timers_.empty()) {
timer = std::move(ongoing_stream_request_timers_.front());
ongoing_stream_request_timers_.pop();
}
}
InferResult* stream_result;
    // FIXME (DLIS-1263): there is no 1:1 mapping between requests and
    // responses in the decoupled streaming case, so this method will record
    // incorrect statistics for decoupled models.
if (timer.get() != nullptr) {
timer->CaptureTimestamp(RequestTimers::Kind::RECV_START);
}
InferResultGrpc::Create(&stream_result, response);
if (timer.get() != nullptr) {
timer->CaptureTimestamp(RequestTimers::Kind::RECV_END);
timer->CaptureTimestamp(RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(*timer);
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
}
}
if (verbose_) {
std::cout << response->DebugString() << std::endl;
}
stream_callback_(stream_result);
response = std::make_shared<inference::ModelStreamInferResponse>();
}
grpc_stream_->Finish();
}
InferenceServerGrpcClient::InferenceServerGrpcClient(
const std::string& url, bool verbose, bool use_ssl,
const SslOptions& ssl_options, const grpc::ChannelArguments& channel_args,
const bool use_cached_channel)
: InferenceServerClient(verbose)
{
stub_ = GetStub(
url, use_ssl, ssl_options, channel_args, use_cached_channel, verbose);
}
InferenceServerGrpcClient::~InferenceServerGrpcClient()
{
exiting_ = true;
  // Close the completion queue and wait for the worker thread to return
async_request_completion_queue_.Shutdown();
  // The thread is not joinable if AsyncInfer() was never called
  // (it is a default-constructed thread before the first AsyncInfer() call)
if (worker_.joinable()) {
worker_.join();
}
bool has_next = true;
GrpcInferRequest* async_request;
bool ok;
do {
has_next =
async_request_completion_queue_.Next((void**)&async_request, &ok);
if (has_next && async_request != nullptr) {
delete async_request;
}
} while (has_next);
StopStream();
}
//==============================================================================
}} // namespace triton::client
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
/// \file
#include <grpcpp/grpcpp.h>
#include <queue>
#include "common.h"
#include "grpc_service.grpc.pb.h"
#include "ipc.h"
#include "model_config.pb.h"
namespace triton { namespace client {
/// The key-value map type to be included in the request
/// metadata
typedef std::map<std::string, std::string> Headers;
struct SslOptions {
explicit SslOptions() {}
// File containing the PEM encoding of the server root certificates.
// If this parameter is empty, the default roots will be used. The
// default roots can be overridden using the
// GRPC_DEFAULT_SSL_ROOTS_FILE_PATH environment variable pointing
// to a file on the file system containing the roots.
std::string root_certificates;
// File containing the PEM encoding of the client's private key.
// This parameter can be empty if the client does not have a
// private key.
std::string private_key;
// File containing the PEM encoding of the client's certificate chain.
// This parameter can be empty if the client does not have a
// certificate chain.
std::string certificate_chain;
};
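/// A minimal sketch (the certificate paths and server address are
/// placeholders) of populating SslOptions and handing them to the gRPC
/// client factory declared below:
/// \code
/// SslOptions ssl_options;
/// ssl_options.root_certificates = "/path/to/ca.pem";
/// ssl_options.private_key = "/path/to/client.key";
/// ssl_options.certificate_chain = "/path/to/client.pem";
/// std::unique_ptr<InferenceServerGrpcClient> client;
/// Error err = InferenceServerGrpcClient::Create(
///     &client, "localhost:8001", false /* verbose */, true /* use_ssl */,
///     ssl_options);
/// \endcode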
// GRPC KeepAlive: https://grpc.github.io/grpc/cpp/md_doc_keepalive.html
struct KeepAliveOptions {
explicit KeepAliveOptions()
: keepalive_time_ms(INT_MAX), keepalive_timeout_ms(20000),
keepalive_permit_without_calls(false), http2_max_pings_without_data(2)
{
}
// The period (in milliseconds) after which a keepalive ping is sent on the
// transport
int keepalive_time_ms;
// The amount of time (in milliseconds) the sender of the keepalive ping waits
// for an acknowledgement. If it does not receive an acknowledgment within
// this time, it will close the connection.
int keepalive_timeout_ms;
// If true, allow keepalive pings to be sent even if there are no calls in
// flight.
bool keepalive_permit_without_calls;
// The maximum number of pings that can be sent when there is no data/header
// frame to be sent. gRPC Core will not continue sending pings if we run over
// the limit. Setting it to 0 allows sending pings without such a restriction.
int http2_max_pings_without_data;
};
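/// A minimal sketch (the values are illustrative only) of overriding the
/// keepalive defaults and passing them to InferenceServerGrpcClient::Create():
/// \code
/// KeepAliveOptions keepalive_options;
/// keepalive_options.keepalive_time_ms = 60000;
/// keepalive_options.keepalive_permit_without_calls = true;
/// std::unique_ptr<InferenceServerGrpcClient> client;
/// Error err = InferenceServerGrpcClient::Create(
///     &client, "localhost:8001", false /* verbose */, false /* use_ssl */,
///     SslOptions(), keepalive_options);
/// \endcode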
//==============================================================================
/// An InferenceServerGrpcClient object is used to perform any kind of
/// communication with the InferenceServer using the gRPC protocol. Most
/// of the methods are thread-safe except Infer, AsyncInfer, StartStream,
/// StopStream and AsyncStreamInfer. Calling these functions from different
/// threads will cause undefined behavior.
///
/// \code
/// std::unique_ptr<InferenceServerGrpcClient> client;
/// InferenceServerGrpcClient::Create(&client, "localhost:8001");
/// bool live;
/// client->IsServerLive(&live);
/// ...
/// ...
/// \endcode
///
class InferenceServerGrpcClient : public InferenceServerClient {
public:
~InferenceServerGrpcClient();
/// Create a client that can be used to communicate with the server.
  /// This is the expected method for most users to create a GRPC client with
  /// the options directly exposed by Triton.
/// \param client Returns a new InferenceServerGrpcClient object.
/// \param server_url The inference server name and port.
/// \param verbose If true generate verbose output when contacting
/// the inference server.
/// \param use_ssl If true use encrypted channel to the server.
/// \param ssl_options Specifies the files required for
/// SSL encryption and authorization.
/// \param keepalive_options Specifies the GRPC KeepAlive options described
/// in https://grpc.github.io/grpc/cpp/md_doc_keepalive.html
/// \param use_cached_channel If false, a new channel is created for each
/// new client instance. When true, re-use old channels from cache for new
/// client instances. The default value is true.
/// \return Error object indicating success or failure.
static Error Create(
std::unique_ptr<InferenceServerGrpcClient>* client,
const std::string& server_url, bool verbose = false, bool use_ssl = false,
const SslOptions& ssl_options = SslOptions(),
const KeepAliveOptions& keepalive_options = KeepAliveOptions(),
const bool use_cached_channel = true);
/// Create a client that can be used to communicate with the server.
/// This method is available for advanced users who need to specify custom
/// grpc::ChannelArguments not exposed by Triton, at their own risk.
/// \param client Returns a new InferenceServerGrpcClient object.
/// \param channel_args Exposes user-defined grpc::ChannelArguments to
/// be set for the client. Triton assumes that the "channel_args" passed
/// to this method are correct and complete, and are set at the user's
/// own risk. For example, GRPC KeepAlive options may be specified directly
/// in this argument rather than passing a KeepAliveOptions object.
/// \param server_url The inference server name and port.
/// \param verbose If true generate verbose output when contacting
/// the inference server.
/// \param use_ssl If true use encrypted channel to the server.
/// \param ssl_options Specifies the files required for
/// SSL encryption and authorization.
/// \param use_cached_channel If false, a new channel is created for each
/// new client instance. When true, re-use old channels from cache for new
/// client instances. The default value is true.
/// \return Error object indicating success or failure.
static Error Create(
std::unique_ptr<InferenceServerGrpcClient>* client,
const std::string& server_url, const grpc::ChannelArguments& channel_args,
bool verbose = false, bool use_ssl = false,
const SslOptions& ssl_options = SslOptions(),
const bool use_cached_channel = true);
/// Contact the inference server and get its liveness.
/// \param live Returns whether the server is live or not.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error IsServerLive(bool* live, const Headers& headers = Headers());
/// Contact the inference server and get its readiness.
/// \param ready Returns whether the server is ready or not.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error IsServerReady(bool* ready, const Headers& headers = Headers());
/// Contact the inference server and get the readiness of specified model.
/// \param ready Returns whether the specified model is ready or not.
/// \param model_name The name of the model to check for readiness.
/// \param model_version The version of the model to check for readiness.
  /// The default value is an empty string which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error IsModelReady(
bool* ready, const std::string& model_name,
const std::string& model_version = "",
const Headers& headers = Headers());
/// Contact the inference server and get its metadata.
/// \param server_metadata Returns the server metadata as
  /// ServerMetadataResponse message.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error ServerMetadata(
inference::ServerMetadataResponse* server_metadata,
const Headers& headers = Headers());
/// Contact the inference server and get the metadata of specified model.
/// \param model_metadata Returns model metadata as ModelMetadataResponse
/// message.
/// \param model_name The name of the model to get metadata.
/// \param model_version The version of the model to get metadata.
  /// The default value is an empty string which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error ModelMetadata(
inference::ModelMetadataResponse* model_metadata,
const std::string& model_name, const std::string& model_version = "",
const Headers& headers = Headers());
/// Contact the inference server and get the configuration of specified model.
/// \param model_config Returns model config as ModelConfigResponse
/// message.
/// \param model_name The name of the model to get configuration.
/// \param model_version The version of the model to get configuration.
  /// The default value is an empty string which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error ModelConfig(
inference::ModelConfigResponse* model_config,
const std::string& model_name, const std::string& model_version = "",
const Headers& headers = Headers());
/// Contact the inference server and get the index of model repository
/// contents.
/// \param repository_index Returns the repository index as
/// RepositoryIndexRequestResponse
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error ModelRepositoryIndex(
inference::RepositoryIndexResponse* repository_index,
const Headers& headers = Headers());
/// Request the inference server to load or reload specified model.
/// \param model_name The name of the model to be loaded or reloaded.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param config Optional JSON representation of a model config provided for
  /// the load request. If provided, this config will be used for
/// loading the model.
  /// \param files Optional map from file paths (with "file:" prefix) in the
  /// override model directory to the file contents.
  /// The files will form the model directory that the model
  /// will be loaded from. If specified, 'config' must be provided as
  /// the model configuration of the override model directory.
/// \return Error object indicating success or failure of the request.
Error LoadModel(
const std::string& model_name, const Headers& headers = Headers(),
const std::string& config = std::string(),
const std::map<std::string, std::vector<char>>& files = {});
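  /// A hedged sketch of loading a model from an override directory; the
  /// "file:1/model.onnx" key layout and the minimal config JSON are
  /// illustrative assumptions, and 'model_bytes' is assumed to hold the
  /// serialized model file:
  /// \code
  /// std::vector<char> model_bytes;  // filled with the model file elsewhere
  /// std::string config = "{\"backend\": \"onnxruntime\"}";  // illustrative
  /// Error err = client->LoadModel(
  ///     "my_model", Headers(), config, {{"file:1/model.onnx", model_bytes}});
  /// \endcode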
/// Request the inference server to unload specified model.
/// \param model_name The name of the model to be unloaded.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error UnloadModel(
const std::string& model_name, const Headers& headers = Headers());
/// Contact the inference server and get the inference statistics for the
/// specified model name and version.
/// \param infer_stat The inference statistics of requested model name and
/// version.
/// \param model_name The name of the model to get inference statistics. The
/// default value is an empty string which means statistics of all models will
/// be returned in the response.
/// \param model_version The version of the model to get inference statistics.
  /// The default value is an empty string which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error ModelInferenceStatistics(
inference::ModelStatisticsResponse* infer_stat,
const std::string& model_name = "", const std::string& model_version = "",
const Headers& headers = Headers());
/// Update the trace settings for the specified model name, or global trace
/// settings if model name is not given.
/// \param response The updated settings as TraceSettingResponse.
/// \param model_name The name of the model to update trace settings. The
/// default value is an empty string which means the global trace settings
/// will be updated.
/// \param settings The new trace setting values. Only the settings listed
/// will be updated. If a trace setting is listed in the map with an empty
/// string, that setting will be cleared.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error UpdateTraceSettings(
inference::TraceSettingResponse* response,
const std::string& model_name = "",
const std::map<std::string, std::vector<std::string>>& settings =
std::map<std::string, std::vector<std::string>>(),
const Headers& headers = Headers());
/// Get the trace settings for the specified model name, or global trace
/// settings if model name is not given.
/// \param settings The trace settings as TraceSettingResponse.
/// \param model_name The name of the model to get trace settings. The
/// default value is an empty string which means the global trace settings
/// will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error GetTraceSettings(
inference::TraceSettingResponse* settings,
const std::string& model_name = "", const Headers& headers = Headers());
/// Contact the inference server and get the status for requested system
/// shared memory.
/// \param status The system shared memory status as
/// SystemSharedMemoryStatusResponse
/// \param region_name The name of the region to query status. The default
/// value is an empty string, which means that the status of all active system
/// shared memory will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error SystemSharedMemoryStatus(
inference::SystemSharedMemoryStatusResponse* status,
const std::string& region_name = "", const Headers& headers = Headers());
/// Request the server to register a system shared memory with the provided
/// details.
/// \param name The name of the region to register.
/// \param key The key of the underlying memory object that contains the
/// system shared memory region.
/// \param byte_size The size of the system shared memory region, in bytes.
/// \param offset Offset, in bytes, within the underlying memory object to
/// the start of the system shared memory region. The default value is zero.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request
Error RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size,
const size_t offset = 0, const Headers& headers = Headers());
/// Request the server to unregister a system shared memory with the
/// specified name.
/// \param name The name of the region to unregister. The default value is
  /// an empty string, which means all the system shared memory regions will be
/// unregistered.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request
Error UnregisterSystemSharedMemory(
const std::string& name = "", const Headers& headers = Headers());
/// Contact the inference server and get the status for requested CUDA
/// shared memory.
/// \param status The CUDA shared memory status as
/// CudaSharedMemoryStatusResponse
/// \param region_name The name of the region to query status. The default
/// value is an empty string, which means that the status of all active CUDA
/// shared memory will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request.
Error CudaSharedMemoryStatus(
inference::CudaSharedMemoryStatusResponse* status,
const std::string& region_name = "", const Headers& headers = Headers());
/// Request the server to register a CUDA shared memory with the provided
/// details.
/// \param name The name of the region to register.
/// \param cuda_shm_handle The cudaIPC handle for the memory object.
/// \param device_id The GPU device ID on which the cudaIPC handle was
/// created.
/// \param byte_size The size of the CUDA shared memory region, in
/// bytes.
/// \param headers Optional map specifying additional HTTP headers to
/// include in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request
Error RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& cuda_shm_handle,
const size_t device_id, const size_t byte_size,
const Headers& headers = Headers());
/// Request the server to unregister a CUDA shared memory with the
/// specified name.
/// \param name The name of the region to unregister. The default value is
  /// an empty string, which means all the CUDA shared memory regions will be
/// unregistered.
/// \param headers Optional map specifying additional HTTP headers to
/// include in the metadata of gRPC request.
/// \return Error object indicating success or failure of the request
Error UnregisterCudaSharedMemory(
const std::string& name = "", const Headers& headers = Headers());
/// Run synchronous inference on server.
/// \param result Returns the result of inference.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the model
/// config will be returned as default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// by gRPC when sending requests. By default compression is not used.
/// \return Error object indicating success or failure of the
/// request.
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
grpc_compression_algorithm compression_algorithm = GRPC_COMPRESS_NONE);
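  /// A minimal synchronous sketch, assuming 'client' was created with Create()
  /// and 'inputs' was prepared elsewhere from InferInput objects; the returned
  /// InferResult is owned by the caller and must be deleted:
  /// \code
  /// InferOptions options("my_model");  // hypothetical model name
  /// InferResult* result = nullptr;
  /// Error err = client->Infer(&result, options, inputs);
  /// if (err.IsOk() && result->RequestStatus().IsOk()) {
  ///   // ... read the outputs from 'result' ...
  /// }
  /// delete result;
  /// \endcode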
/// Run asynchronous inference on server.
/// Once the request is completed, the InferResult pointer will be passed to
  /// the provided 'callback' function. Upon invocation of the callback,
  /// ownership of the InferResult object is transferred to the caller. The
  /// caller may either retrieve the results inside the callback function or
  /// defer that work to a different thread so that the client is unblocked.
  /// To prevent a memory leak, the user must ensure this object gets deleted.
/// \param callback The callback function to be invoked on request completion.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the model
/// config will be returned as default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// by gRPC when sending requests. By default compression is not used.
/// \return Error object indicating success or failure of the request.
Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
grpc_compression_algorithm compression_algorithm = GRPC_COMPRESS_NONE);
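  /// A minimal sketch, reusing 'client', 'options' and 'inputs' from the
  /// Infer() example above; the callback receives ownership of the InferResult
  /// and must delete it:
  /// \code
  /// Error err = client->AsyncInfer(
  ///     [](InferResult* result) {
  ///       if (result->RequestStatus().IsOk()) {
  ///         // ... read the outputs from 'result' ...
  ///       }
  ///       delete result;
  ///     },
  ///     options, inputs);
  /// \endcode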
/// Run multiple synchronous inferences on server.
/// \param results Returns the results of the inferences.
  /// \param options The options for each inference request; alternatively, a
  /// single set of options may be provided and will be used for all requests.
/// \param inputs The vector of InferInput objects describing the model inputs
/// for each inference request.
/// \param outputs Optional vector of InferRequestedOutput objects describing
/// how the output must be returned. If not provided then all the outputs in
  /// the model config will be returned as default settings. Alternatively, a
  /// single set of outputs may be provided and will be used for all requests.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// by gRPC when sending requests. By default compression is not used.
/// \return Error object indicating success or failure of the
/// request.
Error InferMulti(
std::vector<InferResult*>* results,
const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs =
std::vector<std::vector<const InferRequestedOutput*>>(),
const Headers& headers = Headers(),
grpc_compression_algorithm compression_algorithm = GRPC_COMPRESS_NONE);
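  /// A minimal sketch of the broadcast behavior described above: a single
  /// InferOptions entry is applied to every request while each request has its
  /// own input set ('inputs_0' and 'inputs_1' are assumed to be prepared
  /// elsewhere); each returned InferResult must be deleted by the caller:
  /// \code
  /// std::vector<InferResult*> results;
  /// std::vector<InferOptions> options{InferOptions("my_model")};
  /// std::vector<std::vector<InferInput*>> inputs{inputs_0, inputs_1};
  /// Error err = client->InferMulti(&results, options, inputs);
  /// for (auto* result : results) {
  ///   // ... inspect result->RequestStatus() and the outputs ...
  ///   delete result;
  /// }
  /// \endcode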
/// Run multiple asynchronous inferences on server.
/// Once all the requests are completed, the vector of InferResult pointers
  /// will be passed to the provided 'callback' function. Upon invocation of
  /// the callback, ownership of the InferResult objects is transferred to the
  /// caller. The caller may either retrieve the results inside the callback
  /// function or defer that work to a different thread so that the client is
  /// unblocked. To prevent a memory leak, the user must ensure these objects
  /// get deleted.
/// \param callback The callback function to be invoked on the completion of
/// all requests.
  /// \param options The options for each inference request; alternatively, a
  /// single set of options may be provided and will be used for all requests.
/// \param inputs The vector of InferInput objects describing the model inputs
/// for each inference request.
/// \param outputs Optional vector of InferRequestedOutput objects describing
/// how the output must be returned. If not provided then all the outputs in
  /// the model config will be returned as default settings. Alternatively, a
  /// single set of outputs may be provided and will be used for all requests.
/// \param headers Optional map specifying additional HTTP headers to include
/// in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// by gRPC when sending requests. By default compression is not used.
/// \return Error object indicating success or failure of the request.
Error AsyncInferMulti(
OnMultiCompleteFn callback, const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs =
std::vector<std::vector<const InferRequestedOutput*>>(),
const Headers& headers = Headers(),
grpc_compression_algorithm compression_algorithm = GRPC_COMPRESS_NONE);
/// Starts a grpc bi-directional stream to send streaming inferences.
/// \param callback The callback function to be invoked on receiving a
/// response at the stream.
  /// \param enable_stats Indicates whether the client library should record
  /// client-side statistics for inference requests on the stream.
  /// The library does not support client-side statistics for decoupled
  /// streaming. Set this option to false when there is no 1:1 mapping between
  /// request and response on the stream.
/// \param stream_timeout Specifies the end-to-end timeout for the streaming
/// connection in microseconds. The default value is 0 which means that
/// there is no limitation on deadline. The stream will be closed once
/// the specified time elapses.
/// \param headers Optional map specifying additional HTTP headers to
/// include in the metadata of gRPC request.
/// \param compression_algorithm The compression algorithm to be used
/// by gRPC when sending requests. By default compression is not used.
/// \return Error object indicating success or failure of the request.
Error StartStream(
OnCompleteFn callback, bool enable_stats = true,
uint32_t stream_timeout = 0, const Headers& headers = Headers(),
grpc_compression_algorithm compression_algorithm = GRPC_COMPRESS_NONE);
  /// Stops an active grpc bi-directional stream, if one is available.
/// \return Error object indicating success or failure of the request.
Error StopStream();
/// Runs an asynchronous inference over gRPC bi-directional streaming
/// API. A stream must be established with a call to StartStream()
/// before calling this function. All the results will be provided to the
/// callback function provided when starting the stream.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs Optional vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the model
/// config will be returned as default settings.
/// \return Error object indicating success or failure of the request.
Error AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>());
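  /// A minimal sketch of the streaming flow, assuming 'client', 'options' and
  /// 'inputs' are prepared as in the examples above; streamed results arrive
  /// at the callback passed to StartStream() and must be deleted there:
  /// \code
  /// Error err = client->StartStream(
  ///     [](InferResult* result) {
  ///       // ... consume the streamed result ...
  ///       delete result;
  ///     });
  /// err = client->AsyncStreamInfer(options, inputs);
  /// // ... further AsyncStreamInfer() calls may follow ...
  /// err = client->StopStream();
  /// \endcode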
private:
InferenceServerGrpcClient(
const std::string& url, bool verbose, bool use_ssl,
const SslOptions& ssl_options, const grpc::ChannelArguments& channel_args,
const bool use_cached_channel);
Error PreRunProcessing(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
void AsyncTransfer();
void AsyncStreamTransfer();
// The producer-consumer queue used to communicate asynchronously with
// the GRPC runtime.
grpc::CompletionQueue async_request_completion_queue_;
// Required to support the grpc bi-directional streaming API.
InferenceServerClient::OnCompleteFn stream_callback_;
std::thread stream_worker_;
std::shared_ptr<grpc::ClientReaderWriter<
inference::ModelInferRequest, inference::ModelStreamInferResponse>>
grpc_stream_;
grpc::ClientContext grpc_context_;
bool enable_stream_stats_;
std::queue<std::unique_ptr<RequestTimers>> ongoing_stream_request_timers_;
std::mutex stream_mutex_;
// GRPC end point.
std::shared_ptr<inference::GRPCInferenceService::Stub> stub_;
  // Request for the GRPC call; one request object can be reused for multiple
  // calls since it can be overwritten as soon as the GRPC send finishes.
inference::ModelInferRequest infer_request_;
};
}} // namespace triton::client
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
// Include this first to make sure we are a friend of common classes.
#define TRITON_INFERENCE_SERVER_CLIENT_CLASS InferenceServerHttpClient
#include "http_client.h"
#include <curl/curl.h>
#include <atomic>
#include <climits>
#include <cstdint>
#include <deque>
#include <iostream>
#include <string>
#include <utility>
#include "common.h"
#ifdef TRITON_ENABLE_ZLIB
#include <zlib.h>
#endif
extern "C" {
#include "cencode.h"
}
#define TRITONJSON_STATUSTYPE triton::client::Error
#define TRITONJSON_STATUSRETURN(M) return triton::client::Error(M)
#define TRITONJSON_STATUSSUCCESS triton::client::Error::Success
#include "triton/common/triton_json.h"
#ifdef _WIN32
#define strncasecmp(x, y, z) _strnicmp(x, y, z)
#undef min // NOMINMAX did not resolve std::min compile error
#endif //_WIN32
namespace triton { namespace client {
namespace {
constexpr char kContentLengthHTTPHeader[] = "Content-Length";
//==============================================================================
// Global initialization for libcurl. Libcurl requires global
// initialization before any other threads are created and before any
// curl methods are used. The curl_global static object is used to
// perform this initialization.
class CurlGlobal {
public:
~CurlGlobal();
const Error& Status() const { return err_; }
static const CurlGlobal& Get()
{
static CurlGlobal* curl_global = new CurlGlobal();
return *curl_global;
}
private:
CurlGlobal();
Error err_;
};
CurlGlobal::CurlGlobal() : err_(Error::Success)
{
if (curl_global_init(CURL_GLOBAL_ALL) != 0) {
err_ = Error("global initialization failed");
}
}
CurlGlobal::~CurlGlobal()
{
curl_global_cleanup();
}
class CurlGlobalDestroyer {
public:
~CurlGlobalDestroyer() { delete &CurlGlobal::Get(); }
};
static CurlGlobalDestroyer curl_global_destroyer_;
std::string
GetQueryString(const Headers& query_params)
{
std::string query_string;
bool first = true;
for (const auto& pr : query_params) {
if (first) {
first = false;
} else {
query_string += "&";
}
query_string += pr.first + "=" + pr.second;
}
return query_string;
}
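// Example (illustrative only): GetQueryString({{"limit", "10"},
// {"offset", "5"}}) returns "limit=10&offset=5". Note that this helper does
// not URL-encode keys or values.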
// Encodes the contents of the provided buffer into a base64 string. Note
// that the string is not guaranteed to be null-terminated; callers must rely
// on the returned encoded size to get the right contents.
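// The encoded buffer is allocated here with malloc(), so the caller takes
// ownership of '*encoded_ptr' and is responsible for freeing it.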
void
Base64Encode(
const char* raw_ptr, const size_t raw_size, char** encoded_ptr,
int* encoded_size)
{
  // Encode the raw buffer to base64
base64_encodestate es;
base64_init_encodestate(&es);
*encoded_ptr = (char*)malloc(raw_size * 2); /* ~4/3 x raw_size */
*encoded_size = base64_encode_block(raw_ptr, raw_size, *encoded_ptr, &es);
int padding_size = base64_encode_blockend(*encoded_ptr + *encoded_size, &es);
*encoded_size += padding_size;
}
#ifdef TRITON_ENABLE_ZLIB
// libcurl provides automatic decompression, so only implement compression
Error
CompressData(
const InferenceServerHttpClient::CompressionType type,
const std::deque<std::pair<uint8_t*, size_t>>& source,
const size_t source_byte_size,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* compressed_data)
{
// nothing to be compressed
if (source_byte_size == 0) {
return Error("nothing to be compressed");
}
z_stream stream;
stream.zalloc = Z_NULL;
stream.zfree = Z_NULL;
stream.opaque = Z_NULL;
switch (type) {
case InferenceServerHttpClient::CompressionType::GZIP:
if (deflateInit2(
&stream, Z_DEFAULT_COMPRESSION /* level */,
Z_DEFLATED /* method */, 15 | 16 /* windowBits */,
8 /* memLevel */, Z_DEFAULT_STRATEGY /* strategy */) != Z_OK) {
return Error("failed to initialize state for gzip data compression");
}
break;
case InferenceServerHttpClient::CompressionType::DEFLATE: {
if (deflateInit(&stream, Z_DEFAULT_COMPRESSION /* level */) != Z_OK) {
return Error("failed to initialize state for deflate data compression");
}
break;
}
case InferenceServerHttpClient::CompressionType::NONE:
return Error("can't compress data with NONE type");
break;
}
  // ensure the internal state is cleaned up on function return
std::unique_ptr<z_stream, decltype(&deflateEnd)> managed_stream(
&stream, deflateEnd);
  // Reserve the same size as the source for the compressed data; it is
  // unlikely that negative compression (output larger than input) will occur.
std::unique_ptr<char[]> current_reserved_space(new char[source_byte_size]);
stream.next_out =
reinterpret_cast<unsigned char*>(current_reserved_space.get());
stream.avail_out = source_byte_size;
// Compress until end of 'source'
for (auto it = source.begin(); it != source.end(); ++it) {
stream.next_in = reinterpret_cast<unsigned char*>(it->first);
stream.avail_in = it->second;
// run deflate() on input until source has been read in
do {
// Need additional buffer
if (stream.avail_out == 0) {
compressed_data->emplace_back(
std::move(current_reserved_space), source_byte_size);
current_reserved_space.reset(new char[source_byte_size]);
stream.next_out =
reinterpret_cast<unsigned char*>(current_reserved_space.get());
stream.avail_out = source_byte_size;
}
auto flush = (std::next(it) == source.end()) ? Z_FINISH : Z_NO_FLUSH;
auto ret = deflate(&stream, flush);
if (ret == Z_STREAM_ERROR) {
return Error(
"encountered inconsistent stream state during compression");
}
} while (stream.avail_out == 0);
}
// Make sure the last buffer is committed
if (current_reserved_space != nullptr) {
compressed_data->emplace_back(
std::move(current_reserved_space), source_byte_size - stream.avail_out);
}
return Error::Success;
}
#else
Error
CompressData(
const InferenceServerHttpClient::CompressionType type,
const std::deque<std::pair<uint8_t*, size_t>>& source,
const size_t source_byte_size,
std::vector<std::pair<std::unique_ptr<char[]>, size_t>>* compressed_data)
{
return Error("Cannot compress data as ZLIB is not included in this build");
}
#endif
Error
ParseSslCertType(
HttpSslOptions::CERTTYPE cert_type, std::string* curl_cert_type)
{
switch (cert_type) {
case HttpSslOptions::CERTTYPE::CERT_PEM:
*curl_cert_type = "PEM";
break;
case HttpSslOptions::CERTTYPE::CERT_DER:
*curl_cert_type = "DER";
break;
default:
return Error(
"unsupported ssl certificate type encountered. Only PEM and DER are "
"supported.");
}
return Error::Success;
}
Error
ParseSslKeyType(HttpSslOptions::KEYTYPE key_type, std::string* curl_key_type)
{
switch (key_type) {
case HttpSslOptions::KEYTYPE::KEY_PEM:
*curl_key_type = "PEM";
break;
case HttpSslOptions::KEYTYPE::KEY_DER:
*curl_key_type = "DER";
break;
default:
return Error(
"unsupported ssl key type encountered. Only PEM and DER are "
"supported.");
}
return Error::Success;
}
Error
SetSSLCurlOptions(CURL** curl, const HttpSslOptions& ssl_options)
{
curl_easy_setopt(*curl, CURLOPT_SSL_VERIFYPEER, ssl_options.verify_peer);
curl_easy_setopt(*curl, CURLOPT_SSL_VERIFYHOST, ssl_options.verify_host);
if (!ssl_options.ca_info.empty()) {
curl_easy_setopt(*curl, CURLOPT_CAINFO, ssl_options.ca_info.c_str());
}
std::string curl_cert_type;
Error err = ParseSslCertType(ssl_options.cert_type, &curl_cert_type);
if (!err.IsOk()) {
return err;
}
curl_easy_setopt(*curl, CURLOPT_SSLCERTTYPE, curl_cert_type.c_str());
if (!ssl_options.cert.empty()) {
curl_easy_setopt(*curl, CURLOPT_SSLCERT, ssl_options.cert.c_str());
}
std::string curl_key_type;
err = ParseSslKeyType(ssl_options.key_type, &curl_key_type);
if (!err.IsOk()) {
return err;
}
curl_easy_setopt(*curl, CURLOPT_SSLKEYTYPE, curl_key_type.c_str());
if (!ssl_options.key.empty()) {
curl_easy_setopt(*curl, CURLOPT_SSLKEY, ssl_options.key.c_str());
}
return Error::Success;
}
} // namespace
//==============================================================================
class HttpInferRequest : public InferRequest {
public:
HttpInferRequest(
InferenceServerClient::OnCompleteFn callback = nullptr,
const bool verbose = false);
~HttpInferRequest();
  // Initialize the request for HTTP transfer.
Error InitializeRequest(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
// Adds the input data to be delivered to the server
Error AddInput(uint8_t* buf, size_t byte_size);
// Copy into 'buf' up to 'size' bytes of input data. Return the
// actual amount copied in 'input_bytes'.
Error GetNextInput(uint8_t* buf, size_t size, size_t* input_bytes);
Error CompressInput(const InferenceServerHttpClient::CompressionType type);
private:
friend class InferenceServerHttpClient;
friend class InferResultHttp;
Error PrepareRequestJson(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
triton::common::TritonJson::Value* request_json);
protected:
virtual Error ConvertBinaryInputsToJSON(
InferInput& input, triton::common::TritonJson::Value& data_json) const;
virtual Error ConvertBinaryInputToJSON(
const uint8_t* buf, const size_t buf_size, const std::string& datatype,
triton::common::TritonJson::Value& data_json) const;
private:
  // Pointer to the list of HTTP request headers; keep it so that it remains
  // valid during the transfer and can be freed once the transfer is completed.
struct curl_slist* header_list_;
// HTTP response code for the inference request
long http_code_;
size_t total_input_byte_size_;
triton::common::TritonJson::WriteBuffer request_json_;
// Buffer that accumulates the response body.
std::unique_ptr<std::string> infer_response_buffer_;
// The pointers to the input data.
std::deque<std::pair<uint8_t*, size_t>> data_buffers_;
// Placeholder for the compressed data
std::vector<std::pair<std::unique_ptr<char[]>, size_t>> compressed_data_;
size_t response_json_size_;
};
HttpInferRequest::HttpInferRequest(
InferenceServerClient::OnCompleteFn callback, const bool verbose)
: InferRequest(callback, verbose), header_list_(nullptr),
total_input_byte_size_(0), response_json_size_(0)
{
}
HttpInferRequest::~HttpInferRequest()
{
if (header_list_ != nullptr) {
curl_slist_free_all(header_list_);
header_list_ = nullptr;
}
}
Error
HttpInferRequest::InitializeRequest(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
data_buffers_ = {};
total_input_byte_size_ = 0;
http_code_ = 400;
triton::common::TritonJson::Value request_json(
triton::common::TritonJson::ValueType::OBJECT);
Error err = PrepareRequestJson(options, inputs, outputs, &request_json);
if (!err.IsOk()) {
return err;
}
request_json_.Clear();
request_json.Write(&request_json_);
// Add the buffer holding the json to be delivered first
AddInput((uint8_t*)request_json_.Base(), request_json_.Size());
// Prepare buffer to record the response
infer_response_buffer_.reset(new std::string());
return Error::Success;
}
Error
HttpInferRequest::PrepareRequestJson(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
triton::common::TritonJson::Value* request_json)
{
  // String references can be used because the JSON is serialized before the
  // 'options', 'inputs' and 'outputs' lifetimes end.
request_json->AddStringRef(
"id", options.request_id_.c_str(), options.request_id_.size());
if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "") ||
(options.priority_ != 0) || (options.server_timeout_ != 0) ||
outputs.empty()) {
triton::common::TritonJson::Value parameters_json(
*request_json, triton::common::TritonJson::ValueType::OBJECT);
{
if ((options.sequence_id_ != 0) || (options.sequence_id_str_ != "")) {
if (options.sequence_id_ != 0) {
parameters_json.AddUInt("sequence_id", options.sequence_id_);
} else {
parameters_json.AddString(
"sequence_id", options.sequence_id_str_.c_str(),
options.sequence_id_str_.size());
}
parameters_json.AddBool("sequence_start", options.sequence_start_);
parameters_json.AddBool("sequence_end", options.sequence_end_);
}
if (options.priority_ != 0) {
parameters_json.AddUInt("priority", options.priority_);
}
if (options.server_timeout_ != 0) {
parameters_json.AddUInt("timeout", options.server_timeout_);
}
// If no outputs are provided then set the request parameter
// to return all outputs as binary data.
if (outputs.empty()) {
parameters_json.AddBool("binary_data_output", true);
}
}
request_json->Add("parameters", std::move(parameters_json));
}
if (!inputs.empty()) {
triton::common::TritonJson::Value inputs_json(
*request_json, triton::common::TritonJson::ValueType::ARRAY);
for (const auto io : inputs) {
triton::common::TritonJson::Value io_json(
*request_json, triton::common::TritonJson::ValueType::OBJECT);
io_json.AddStringRef("name", io->Name().c_str(), io->Name().size());
io_json.AddStringRef(
"datatype", io->Datatype().c_str(), io->Datatype().size());
triton::common::TritonJson::Value shape_json(
*request_json, triton::common::TritonJson::ValueType::ARRAY);
for (const auto dim : io->Shape()) {
shape_json.AppendUInt(dim);
}
io_json.Add("shape", std::move(shape_json));
triton::common::TritonJson::Value ioparams_json(
*request_json, triton::common::TritonJson::ValueType::OBJECT);
if (io->IsSharedMemory()) {
std::string region_name;
size_t offset;
size_t byte_size;
Error err = io->SharedMemoryInfo(&region_name, &byte_size, &offset);
if (!err.IsOk()) {
return err;
}
ioparams_json.AddString(
"shared_memory_region", region_name.c_str(), region_name.size());
ioparams_json.AddUInt("shared_memory_byte_size", byte_size);
if (offset != 0) {
ioparams_json.AddUInt("shared_memory_offset", offset);
}
io_json.Add("parameters", std::move(ioparams_json));
} else if (io->BinaryData()) {
size_t byte_size;
Error err = io->ByteSize(&byte_size);
if (!err.IsOk()) {
return err;
}
ioparams_json.AddUInt("binary_data_size", byte_size);
io_json.Add("parameters", std::move(ioparams_json));
} else {
triton::common::TritonJson::Value data_json(
*request_json, triton::common::TritonJson::ValueType::ARRAY);
Error err = ConvertBinaryInputsToJSON(*io, data_json);
if (!err.IsOk()) {
return err;
}
io_json.Add("data", std::move(data_json));
}
inputs_json.Append(std::move(io_json));
}
request_json->Add("inputs", std::move(inputs_json));
}
if (!outputs.empty()) {
triton::common::TritonJson::Value outputs_json(
*request_json, triton::common::TritonJson::ValueType::ARRAY);
for (const auto io : outputs) {
triton::common::TritonJson::Value io_json(
*request_json, triton::common::TritonJson::ValueType::OBJECT);
io_json.AddStringRef("name", io->Name().c_str(), io->Name().size());
triton::common::TritonJson::Value ioparams_json(
*request_json, triton::common::TritonJson::ValueType::OBJECT);
if (io->ClassificationCount() > 0) {
ioparams_json.AddUInt("classification", io->ClassificationCount());
}
if (io->IsSharedMemory()) {
std::string region_name;
size_t offset;
size_t byte_size;
Error err = io->SharedMemoryInfo(&region_name, &byte_size, &offset);
if (!err.IsOk()) {
return err;
}
ioparams_json.AddString(
"shared_memory_region", region_name.c_str(), region_name.size());
ioparams_json.AddUInt("shared_memory_byte_size", byte_size);
if (offset != 0) {
ioparams_json.AddUInt("shared_memory_offset", offset);
}
} else {
ioparams_json.AddBool("binary_data", io->BinaryData());
}
io_json.Add("parameters", std::move(ioparams_json));
outputs_json.Append(std::move(io_json));
}
request_json->Add("outputs", std::move(outputs_json));
}
return Error::Success;
}
Error
HttpInferRequest::ConvertBinaryInputsToJSON(
InferInput& input, triton::common::TritonJson::Value& data_json) const
{
input.PrepareForRequest();
bool end_of_input{false};
while (!end_of_input) {
const uint8_t* buf{nullptr};
size_t buf_size{0};
input.GetNext(&buf, &buf_size, &end_of_input);
size_t element_count{1};
for (size_t i = 1; i < input.Shape().size(); i++) {
element_count *= input.Shape()[i];
}
if (buf != nullptr) {
Error err = ConvertBinaryInputToJSON(
buf, element_count, input.Datatype(), data_json);
if (!err.IsOk()) {
return err;
}
}
}
return Error::Success;
}
Error
HttpInferRequest::ConvertBinaryInputToJSON(
const uint8_t* buf, const size_t element_count, const std::string& datatype,
triton::common::TritonJson::Value& data_json) const
{
if (datatype == "BOOL") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendBool(reinterpret_cast<const bool*>(buf)[i]);
}
} else if (datatype == "UINT8") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendUInt(reinterpret_cast<const uint8_t*>(buf)[i]);
}
} else if (datatype == "UINT16") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendUInt(reinterpret_cast<const uint16_t*>(buf)[i]);
}
} else if (datatype == "UINT32") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendUInt(reinterpret_cast<const uint32_t*>(buf)[i]);
}
} else if (datatype == "UINT64") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendUInt(reinterpret_cast<const uint64_t*>(buf)[i]);
}
} else if (datatype == "INT8") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendInt(reinterpret_cast<const int8_t*>(buf)[i]);
}
} else if (datatype == "INT16") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendInt(reinterpret_cast<const int16_t*>(buf)[i]);
}
} else if (datatype == "INT32") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendInt(reinterpret_cast<const int32_t*>(buf)[i]);
}
} else if (datatype == "INT64") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendInt(reinterpret_cast<const int64_t*>(buf)[i]);
}
} else if (datatype == "FP16") {
return Error(
"datatype '" + datatype +
"' is not supported with JSON. Please use the binary data format");
} else if (datatype == "FP32") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendDouble(reinterpret_cast<const float*>(buf)[i]);
}
} else if (datatype == "FP64") {
for (size_t i = 0; i < element_count; i++) {
data_json.AppendDouble(reinterpret_cast<const double*>(buf)[i]);
}
} else if (datatype == "BYTES") {
size_t offset{0};
for (size_t i = 0; i < element_count; i++) {
const size_t len{*reinterpret_cast<const uint32_t*>(buf + offset)};
data_json.AppendStringRef(
reinterpret_cast<const char*>(buf + offset + sizeof(const uint32_t)),
len);
offset += sizeof(const uint32_t) + len;
}
} else if (datatype == "BF16") {
return Error(
"datatype '" + datatype +
"' is not supported with JSON. Please use the binary data format");
} else {
return Error("datatype '" + datatype + "' is invalid");
}
return Error::Success;
}
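// Illustration (not part of the client code): the BYTES branch above reads
// each element as a 4-byte length prefix followed by the raw bytes. A minimal
// sketch of producing such a buffer on the caller side, for the strings "ab"
// and "c" (on a little-endian host this yields the bytes
// 02 00 00 00 'a' 'b' 01 00 00 00 'c'):
//
//   std::string serialized;
//   for (const std::string& s : {std::string("ab"), std::string("c")}) {
//     const uint32_t len = s.size();
//     serialized.append(reinterpret_cast<const char*>(&len), sizeof(len));
//     serialized.append(s);
//   }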
Error
HttpInferRequest::AddInput(uint8_t* buf, size_t byte_size)
{
data_buffers_.push_back(std::pair<uint8_t*, size_t>(buf, byte_size));
total_input_byte_size_ += byte_size;
return Error::Success;
}
Error
HttpInferRequest::GetNextInput(uint8_t* buf, size_t size, size_t* input_bytes)
{
*input_bytes = 0;
while (!data_buffers_.empty() && size > 0) {
const size_t csz = std::min(data_buffers_.front().second, size);
if (csz > 0) {
const uint8_t* input_ptr = data_buffers_.front().first;
std::copy(input_ptr, input_ptr + csz, buf);
size -= csz;
buf += csz;
*input_bytes += csz;
data_buffers_.front().first += csz;
data_buffers_.front().second -= csz;
}
if (data_buffers_.front().second == 0) {
data_buffers_.pop_front();
}
}
// Set end timestamp if all inputs have been sent.
if (data_buffers_.empty()) {
Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}
return Error::Success;
}
Error
HttpInferRequest::CompressInput(
const InferenceServerHttpClient::CompressionType type)
{
auto err = CompressData(
type, data_buffers_, total_input_byte_size_, &compressed_data_);
if (!err.IsOk()) {
return err;
}
data_buffers_.clear();
total_input_byte_size_ = 0;
for (const auto& data : compressed_data_) {
data_buffers_.push_back(std::pair<uint8_t*, size_t>(
reinterpret_cast<uint8_t*>(data.first.get()), data.second));
total_input_byte_size_ += data.second;
}
return Error::Success;
}
//==============================================================================
class InferResultHttp : public InferResult {
public:
static void Create(
InferResult** infer_result,
std::shared_ptr<HttpInferRequest> infer_request);
static Error Create(InferResult** infer_result, const Error err);
Error RequestStatus() const override;
Error ModelName(std::string* name) const override;
Error ModelVersion(std::string* version) const override;
Error Id(std::string* id) const override;
Error Shape(const std::string& output_name, std::vector<int64_t>* shape)
const override;
Error Datatype(
const std::string& output_name, std::string* datatype) const override;
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override;
Error IsFinalResponse(bool* is_final_response) const override;
Error IsNullResponse(bool* is_null_response) const override;
Error StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const override;
std::string DebugString() const override;
private:
InferResultHttp(std::shared_ptr<HttpInferRequest> infer_request);
InferResultHttp(const Error err) : status_(err) {}
protected:
InferResultHttp() {}
~InferResultHttp();
virtual Error ConvertJSONOutputToBinary(
triton::common::TritonJson::Value& data_json, const std::string& datatype,
const uint8_t** buf, size_t* buf_size) const;
private:
std::map<std::string, triton::common::TritonJson::Value>
output_name_to_result_map_;
std::map<std::string, std::pair<const uint8_t*, const size_t>>
output_name_to_buffer_map_;
Error status_;
triton::common::TritonJson::Value response_json_;
std::shared_ptr<HttpInferRequest> infer_request_;
bool binary_data_{true};
bool is_final_response_{true};
bool is_null_response_{false};
};
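// A minimal usage sketch of the InferResult accessors implemented below,
// assuming a completed inference and a hypothetical output tensor named
// "OUTPUT0" (error handling abbreviated):
//
//   std::vector<int64_t> shape;
//   std::string datatype;
//   const uint8_t* raw = nullptr;
//   size_t raw_size = 0;
//   result->Shape("OUTPUT0", &shape);
//   result->Datatype("OUTPUT0", &datatype);
//   result->RawData("OUTPUT0", &raw, &raw_size);
//   // For BYTES outputs, StringData() decodes the length-prefixed elements:
//   std::vector<std::string> strings;
//   result->StringData("OUTPUT0", &strings);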
void
InferResultHttp::Create(
InferResult** infer_result, std::shared_ptr<HttpInferRequest> infer_request)
{
*infer_result =
reinterpret_cast<InferResult*>(new InferResultHttp(infer_request));
}
Error
InferResultHttp::Create(InferResult** infer_result, const Error err)
{
if (err.IsOk()) {
return Error(
"Error is not provided for error reporting override of "
"InferResultHttp::Create()");
}
*infer_result = reinterpret_cast<InferResult*>(new InferResultHttp(err));
return Error::Success;
}
Error
InferResultHttp::ModelName(std::string* name) const
{
if (!status_.IsOk()) {
return status_;
}
const char* name_str;
size_t name_strlen;
Error err =
response_json_.MemberAsString("model_name", &name_str, &name_strlen);
if (!err.IsOk()) {
return Error("model name was not returned in the response");
}
name->assign(name_str, name_strlen);
return Error::Success;
}
Error
InferResultHttp::ModelVersion(std::string* version) const
{
if (!status_.IsOk()) {
return status_;
}
const char* version_str;
size_t version_strlen;
Error err = response_json_.MemberAsString(
"model_version", &version_str, &version_strlen);
if (!err.IsOk()) {
return Error("model version was not returned in the response");
}
version->assign(version_str, version_strlen);
return Error::Success;
}
Error
InferResultHttp::Id(std::string* id) const
{
if (!status_.IsOk()) {
return status_;
}
const char* id_str;
size_t id_strlen;
Error err = response_json_.MemberAsString("id", &id_str, &id_strlen);
if (!err.IsOk()) {
return Error("model id was not returned in the response");
}
id->assign(id_str, id_strlen);
return Error::Success;
}
namespace {
Error
ShapeHelper(
const std::string& result_name,
const triton::common::TritonJson::Value& result_json,
std::vector<int64_t>* shape)
{
triton::common::TritonJson::Value shape_json;
if (!const_cast<triton::common::TritonJson::Value&>(result_json)
.Find("shape", &shape_json)) {
return Error(
"The response does not contain shape for output name " + result_name);
}
for (size_t i = 0; i < shape_json.ArraySize(); i++) {
int64_t dim;
Error err = shape_json.IndexAsInt(i, &dim);
if (!err.IsOk()) {
return err;
}
shape->push_back(dim);
}
return Error::Success;
}
} // namespace
Error
InferResultHttp::Shape(
const std::string& output_name, std::vector<int64_t>* shape) const
{
if (!status_.IsOk()) {
return status_;
}
shape->clear();
auto itr = output_name_to_result_map_.find(output_name);
if (itr == output_name_to_result_map_.end()) {
return Error(
"The response does not contain results for output name " + output_name);
}
return ShapeHelper(output_name, itr->second, shape);
}
Error
InferResultHttp::Datatype(
const std::string& output_name, std::string* datatype) const
{
if (!status_.IsOk()) {
return status_;
}
auto itr = output_name_to_result_map_.find(output_name);
if (itr == output_name_to_result_map_.end()) {
return Error(
"The response does not contain results for output name " + output_name);
}
const char* dtype_str;
size_t dtype_strlen;
Error err = itr->second.MemberAsString("datatype", &dtype_str, &dtype_strlen);
if (!err.IsOk()) {
return Error(
"The response does not contain datatype for output name " +
output_name);
}
datatype->assign(dtype_str, dtype_strlen);
return Error::Success;
}
Error
InferResultHttp::RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const
{
if (!status_.IsOk()) {
return status_;
}
auto itr = output_name_to_buffer_map_.find(output_name);
if (itr != output_name_to_buffer_map_.end()) {
*buf = itr->second.first;
*byte_size = itr->second.second;
} else {
return Error(
"The response does not contain results for output name " + output_name);
}
return Error::Success;
}
Error
InferResultHttp::IsFinalResponse(bool* is_final_response) const
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = is_final_response_;
return Error::Success;
}
Error
InferResultHttp::IsNullResponse(bool* is_null_response) const
{
if (is_null_response == nullptr) {
return Error("is_null_response cannot be nullptr");
}
*is_null_response = is_null_response_;
return Error::Success;
}
Error
InferResultHttp::StringData(
const std::string& output_name,
std::vector<std::string>* string_result) const
{
if (!status_.IsOk()) {
return status_;
}
std::string datatype;
Error err = Datatype(output_name, &datatype);
if (!err.IsOk()) {
return err;
}
if (datatype.compare("BYTES") != 0) {
return Error(
"This function supports tensors with datatype 'BYTES', requested "
"output tensor '" +
output_name + "' with datatype '" + datatype + "'");
}
const uint8_t* buf;
size_t byte_size;
  err = RawData(output_name, &buf, &byte_size);
  if (!err.IsOk()) {
    return err;
  }
  string_result->clear();
size_t buf_offset = 0;
while (byte_size > buf_offset) {
const uint32_t element_size =
*(reinterpret_cast<const uint32_t*>(buf + buf_offset));
string_result->emplace_back(
reinterpret_cast<const char*>(buf + buf_offset + sizeof(element_size)),
element_size);
buf_offset += (sizeof(element_size) + element_size);
}
return Error::Success;
}
std::string
InferResultHttp::DebugString() const
{
if (!status_.IsOk()) {
return status_.Message();
}
triton::common::TritonJson::WriteBuffer buffer;
Error err = response_json_.Write(&buffer);
if (!err.IsOk()) {
return "<failed>";
}
return buffer.Contents();
}
Error
InferResultHttp::RequestStatus() const
{
return status_;
}
InferResultHttp::InferResultHttp(
std::shared_ptr<HttpInferRequest> infer_request)
: infer_request_(infer_request)
{
size_t offset = infer_request->response_json_size_;
if (infer_request->http_code_ == 499) {
status_ = Error("Deadline Exceeded");
} else {
if (offset != 0) {
if (infer_request->verbose_) {
std::cout << "inference response: "
<< infer_request->infer_response_buffer_->substr(0, offset)
<< std::endl;
}
status_ = response_json_.Parse(
(char*)infer_request->infer_response_buffer_.get()->c_str(), offset);
} else {
if (infer_request->verbose_) {
std::cout << "inference response: "
<< *infer_request->infer_response_buffer_ << std::endl;
}
status_ = response_json_.Parse(
(char*)infer_request->infer_response_buffer_.get()->c_str());
}
}
// There should be a valid JSON response in all cases. Either the
// successful infer response or an error response.
if (status_.IsOk()) {
if (infer_request->http_code_ != 200) {
const char* err_str;
size_t err_strlen;
if (!response_json_.MemberAsString("error", &err_str, &err_strlen)
.IsOk()) {
status_ = Error("inference failed with unknown error");
} else {
status_ = Error(std::string(err_str, err_strlen));
}
} else {
triton::common::TritonJson::Value outputs_json;
if (response_json_.Find("outputs", &outputs_json)) {
for (size_t i = 0; i < outputs_json.ArraySize(); i++) {
triton::common::TritonJson::Value output_json;
status_ = outputs_json.IndexAsObject(i, &output_json);
if (!status_.IsOk()) {
break;
}
const char* name_str;
size_t name_strlen;
status_ = output_json.MemberAsString("name", &name_str, &name_strlen);
if (!status_.IsOk()) {
break;
}
std::string output_name(name_str, name_strlen);
triton::common::TritonJson::Value param_json, data_json;
if (output_json.Find("parameters", &param_json)) {
uint64_t data_size = 0;
status_ = param_json.MemberAsUInt("binary_data_size", &data_size);
if (!status_.IsOk()) {
break;
}
output_name_to_buffer_map_.emplace(
output_name,
std::pair<const uint8_t*, const size_t>(
(uint8_t*)(infer_request->infer_response_buffer_.get()
->c_str()) +
offset,
data_size));
offset += data_size;
} else if (output_json.Find("data", &data_json)) {
binary_data_ = false;
std::string datatype;
status_ = output_json.MemberAsString("datatype", &datatype);
if (!status_.IsOk()) {
break;
}
const uint8_t* buf{nullptr};
size_t buf_size{0};
status_ =
ConvertJSONOutputToBinary(data_json, datatype, &buf, &buf_size);
if (!status_.IsOk()) {
break;
}
output_name_to_buffer_map_.emplace(
output_name,
std::pair<const uint8_t*, const size_t>(buf, buf_size));
}
output_name_to_result_map_[output_name] = std::move(output_json);
}
}
}
}
}
InferResultHttp::~InferResultHttp()
{
if (binary_data_) {
return;
}
for (auto& buf_pair : output_name_to_buffer_map_) {
    const uint8_t* buf{buf_pair.second.first};
    // These buffers were allocated with new[] in ConvertJSONOutputToBinary().
    delete[] buf;
}
}
Error
InferResultHttp::ConvertJSONOutputToBinary(
triton::common::TritonJson::Value& data_json, const std::string& datatype,
const uint8_t** buf, size_t* buf_size) const
{
const size_t element_count{data_json.ArraySize()};
if (datatype == "BOOL") {
*buf = reinterpret_cast<const uint8_t*>(new bool[element_count]);
*buf_size = sizeof(bool) * element_count;
for (size_t i = 0; i < element_count; i++) {
bool value{false};
data_json.IndexAsBool(i, &value);
const_cast<bool*>(reinterpret_cast<const bool*>(*buf))[i] = value;
}
} else if (datatype == "UINT8") {
*buf = reinterpret_cast<const uint8_t*>(new uint8_t[element_count]);
*buf_size = sizeof(uint8_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
uint64_t value{0};
data_json.IndexAsUInt(i, &value);
const_cast<uint8_t*>(reinterpret_cast<const uint8_t*>(*buf))[i] = value;
}
} else if (datatype == "UINT16") {
*buf = reinterpret_cast<const uint8_t*>(new uint16_t[element_count]);
*buf_size = sizeof(uint16_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
uint64_t value{0};
data_json.IndexAsUInt(i, &value);
const_cast<uint16_t*>(reinterpret_cast<const uint16_t*>(*buf))[i] = value;
}
} else if (datatype == "UINT32") {
*buf = reinterpret_cast<const uint8_t*>(new uint32_t[element_count]);
*buf_size = sizeof(uint32_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
uint64_t value{0};
data_json.IndexAsUInt(i, &value);
const_cast<uint32_t*>(reinterpret_cast<const uint32_t*>(*buf))[i] = value;
}
} else if (datatype == "UINT64") {
*buf = reinterpret_cast<const uint8_t*>(new uint64_t[element_count]);
*buf_size = sizeof(uint64_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
uint64_t value{0};
data_json.IndexAsUInt(i, &value);
const_cast<uint64_t*>(reinterpret_cast<const uint64_t*>(*buf))[i] = value;
}
} else if (datatype == "INT8") {
*buf = reinterpret_cast<const uint8_t*>(new int8_t[element_count]);
*buf_size = sizeof(int8_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
int64_t value{0};
data_json.IndexAsInt(i, &value);
const_cast<int8_t*>(reinterpret_cast<const int8_t*>(*buf))[i] = value;
}
} else if (datatype == "INT16") {
*buf = reinterpret_cast<const uint8_t*>(new int16_t[element_count]);
*buf_size = sizeof(int16_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
int64_t value{0};
data_json.IndexAsInt(i, &value);
const_cast<int16_t*>(reinterpret_cast<const int16_t*>(*buf))[i] = value;
}
} else if (datatype == "INT32") {
*buf = reinterpret_cast<const uint8_t*>(new int32_t[element_count]);
*buf_size = sizeof(int32_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
int64_t value{0};
data_json.IndexAsInt(i, &value);
const_cast<int32_t*>(reinterpret_cast<const int32_t*>(*buf))[i] = value;
}
} else if (datatype == "INT64") {
*buf = reinterpret_cast<const uint8_t*>(new int64_t[element_count]);
*buf_size = sizeof(int64_t) * element_count;
for (size_t i = 0; i < element_count; i++) {
int64_t value{0};
data_json.IndexAsInt(i, &value);
const_cast<int64_t*>(reinterpret_cast<const int64_t*>(*buf))[i] = value;
}
} else if (datatype == "FP16") {
return Error("datatype '" + datatype + "' is not supported with JSON.");
} else if (datatype == "FP32") {
*buf = reinterpret_cast<const uint8_t*>(new float[element_count]);
*buf_size = sizeof(float) * element_count;
for (size_t i = 0; i < element_count; i++) {
double value{0.0};
data_json.IndexAsDouble(i, &value);
const_cast<float*>(reinterpret_cast<const float*>(*buf))[i] = value;
}
} else if (datatype == "FP64") {
    *buf = reinterpret_cast<const uint8_t*>(new double[element_count]);
*buf_size = sizeof(double) * element_count;
for (size_t i = 0; i < element_count; i++) {
double value{0.0};
data_json.IndexAsDouble(i, &value);
const_cast<double*>(reinterpret_cast<const double*>(*buf))[i] = value;
}
} else if (datatype == "BYTES") {
size_t total_buf_size{0};
std::vector<std::pair<const char*, size_t>> bytes_pairs{};
bytes_pairs.resize(element_count);
for (size_t i = 0; i < element_count; i++) {
data_json.IndexAsString(i, &bytes_pairs[i].first, &bytes_pairs[i].second);
total_buf_size += sizeof(const uint32_t) + bytes_pairs[i].second;
}
*buf = reinterpret_cast<const uint8_t*>(new uint8_t[total_buf_size]);
*buf_size = total_buf_size;
size_t offset{0};
for (const auto& bytes_pair : bytes_pairs) {
const char* bytes{bytes_pair.first};
size_t bytes_size{bytes_pair.second};
std::memcpy(
const_cast<uint8_t*>(*buf + offset), &bytes_size,
sizeof(const uint32_t));
std::memcpy(
const_cast<uint8_t*>(*buf + offset + sizeof(const uint32_t)), bytes,
bytes_size);
offset += sizeof(const uint32_t) + bytes_size;
}
} else if (datatype == "BF16") {
return Error("datatype '" + datatype + "' is not supported with JSON.");
} else {
return Error("datatype '" + datatype + "' is invalid");
}
return Error::Success;
}
//==============================================================================
Error
InferenceServerHttpClient::GenerateRequestBody(
std::vector<char>* request_body, size_t* header_length,
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
auto infer_request = std::unique_ptr<HttpInferRequest>(
new HttpInferRequest(nullptr /* callback */, false));
// Prepare the request object to provide the data for inference.
Error err = infer_request->InitializeRequest(options, inputs, outputs);
if (!err.IsOk()) {
return err;
}
// Add the buffers holding input tensor data
for (const auto this_input : inputs) {
if (!this_input->IsSharedMemory()) {
this_input->PrepareForRequest();
bool end_of_input = false;
while (!end_of_input) {
const uint8_t* buf;
size_t buf_size;
this_input->GetNext(&buf, &buf_size, &end_of_input);
if (buf != nullptr) {
infer_request->AddInput(const_cast<uint8_t*>(buf), buf_size);
}
}
}
}
*header_length = infer_request->request_json_.Size();
*request_body = std::vector<char>(infer_request->total_input_byte_size_);
size_t remaining_bytes = infer_request->total_input_byte_size_;
size_t actual_copied_bytes = 0;
char* current_pos = request_body->data();
while (true) {
err = infer_request->GetNextInput(
reinterpret_cast<uint8_t*>(current_pos), remaining_bytes,
&actual_copied_bytes);
if (!err.IsOk()) {
return err;
}
if (actual_copied_bytes == remaining_bytes) {
break;
} else {
current_pos += actual_copied_bytes;
remaining_bytes -= actual_copied_bytes;
}
}
return Error::Success;
}
Error
InferenceServerHttpClient::ParseResponseBody(
InferResult** result, const std::vector<char>& response_body,
const size_t header_length)
{
// Result data is actually stored in request object
auto infer_request = std::shared_ptr<HttpInferRequest>(
new HttpInferRequest(nullptr /* callback */, false));
infer_request->http_code_ = 200;
infer_request->response_json_size_ = header_length;
infer_request->infer_response_buffer_.reset(
new std::string(response_body.data(), response_body.size()));
InferResultHttp::Create(result, infer_request);
return Error::Success;
}
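// A minimal sketch pairing GenerateRequestBody() and ParseResponseBody(),
// useful when the HTTP transfer itself is handled by some other transport
// layer. 'client', 'options', 'inputs' and 'outputs' are assumed to be
// constructed elsewhere; the transport step is elided:
//
//   std::vector<char> request_body;
//   size_t header_length = 0;
//   client->GenerateRequestBody(
//       &request_body, &header_length, options, inputs, outputs);
//   // ... send 'request_body' with the Inference-Header-Content-Length
//   // header set to 'header_length', then collect the raw response bytes
//   // into 'response_body' and the response header value into
//   // 'response_header_length' ...
//   InferResult* result = nullptr;
//   client->ParseResponseBody(&result, response_body, response_header_length);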
Error
InferenceServerHttpClient::Create(
std::unique_ptr<InferenceServerHttpClient>* client,
const std::string& server_url, bool verbose,
const HttpSslOptions& ssl_options)
{
client->reset(
new InferenceServerHttpClient(server_url, verbose, ssl_options));
return Error::Success;
}
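// A minimal sketch of constructing the client; the URL is illustrative and
// the default HttpSslOptions argument declared in the header is assumed:
//
//   std::unique_ptr<InferenceServerHttpClient> client;
//   Error err = InferenceServerHttpClient::Create(
//       &client, "localhost:8000", false /* verbose */);
//   if (!err.IsOk()) {
//     std::cerr << "unable to create client: " << err << std::endl;
//   }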
InferenceServerHttpClient::InferenceServerHttpClient(
const std::string& url, bool verbose, const HttpSslOptions& ssl_options)
: InferenceServerClient(verbose), url_(url), ssl_options_(ssl_options),
easy_handle_(reinterpret_cast<void*>(curl_easy_init())),
multi_handle_(curl_multi_init())
{
}
InferenceServerHttpClient::~InferenceServerHttpClient()
{
exiting_ = true;
  // The worker thread is not joinable if AsyncInfer() was never called
  // (it remains a default-constructed thread until the first AsyncInfer() call)
if (worker_.joinable()) {
cv_.notify_all();
worker_.join();
}
if (easy_handle_ != nullptr) {
curl_easy_cleanup(reinterpret_cast<CURL*>(easy_handle_));
}
if (multi_handle_ != nullptr) {
for (auto& request : ongoing_async_requests_) {
CURL* easy_handle = reinterpret_cast<CURL*>(request.first);
curl_multi_remove_handle(multi_handle_, easy_handle);
curl_easy_cleanup(easy_handle);
}
curl_multi_cleanup(multi_handle_);
}
}
Error
InferenceServerHttpClient::IsServerLive(
bool* live, const Headers& headers, const Parameters& query_params)
{
Error err;
std::string request_uri(url_ + "/v2/health/live");
long http_code;
std::string response;
err = Get(request_uri, headers, query_params, &response, &http_code);
  *live = (http_code == 200);
return err;
}
Error
InferenceServerHttpClient::IsServerReady(
bool* ready, const Headers& headers, const Parameters& query_params)
{
Error err;
  std::string request_uri(url_ + "/v2/health/ready");
long http_code;
std::string response;
err = Get(request_uri, headers, query_params, &response, &http_code);
  *ready = (http_code == 200);
return err;
}
Error
InferenceServerHttpClient::IsModelReady(
bool* ready, const std::string& model_name,
const std::string& model_version, const Headers& headers,
const Parameters& query_params)
{
Error err;
std::string request_uri(url_ + "/v2/models/" + model_name);
if (!model_version.empty()) {
request_uri = request_uri + "/versions/" + model_version;
}
request_uri = request_uri + "/ready";
long http_code;
std::string response;
err = Get(request_uri, headers, query_params, &response, &http_code);
  *ready = (http_code == 200);
return err;
}
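// A minimal sketch of the liveness/readiness helpers above, using a client
// created as sketched after Create(); the model name is illustrative and the
// default header/query-parameter arguments declared in the header are assumed:
//
//   bool live = false, ready = false, model_ready = false;
//   client->IsServerLive(&live);
//   client->IsServerReady(&ready);
//   client->IsModelReady(&model_ready, "simple", "" /* latest version */);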
Error
InferenceServerHttpClient::ServerMetadata(
std::string* server_metadata, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2");
return Get(request_uri, headers, query_params, server_metadata);
}
Error
InferenceServerHttpClient::ModelMetadata(
std::string* model_metadata, const std::string& model_name,
const std::string& model_version, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/models/" + model_name);
if (!model_version.empty()) {
request_uri = request_uri + "/versions/" + model_version;
}
return Get(request_uri, headers, query_params, model_metadata);
}
Error
InferenceServerHttpClient::ModelConfig(
std::string* model_config, const std::string& model_name,
const std::string& model_version, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/models/" + model_name);
if (!model_version.empty()) {
request_uri = request_uri + "/versions/" + model_version;
}
request_uri = request_uri + "/config";
return Get(request_uri, headers, query_params, model_config);
}
Error
InferenceServerHttpClient::ModelRepositoryIndex(
std::string* repository_index, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/repository/index");
std::string request; // empty request body
return Post(request_uri, request, headers, query_params, repository_index);
}
Error
InferenceServerHttpClient::LoadModel(
const std::string& model_name, const Headers& headers,
const Parameters& query_params, const std::string& config,
const std::map<std::string, std::vector<char>>& files)
{
std::string request_uri(
url_ + "/v2/repository/models/" + model_name + "/load");
triton::common::TritonJson::Value request_json(
triton::common::TritonJson::ValueType::OBJECT);
bool has_param = false;
triton::common::TritonJson::Value parameters_json(
request_json, triton::common::TritonJson::ValueType::OBJECT);
if (!config.empty()) {
has_param = true;
parameters_json.AddStringRef("config", config.c_str());
}
for (const auto& file : files) {
    // Base64 encode the file content as required by the HTTP protocol.
    // encoded_handle must be freed after use to prevent a memory leak.
char* encoded_handle = nullptr;
int encoded_size;
Base64Encode(
file.second.data(), file.second.size(), &encoded_handle, &encoded_size);
if (encoded_handle == nullptr) {
return Error("Failed to base64 encode the file content");
}
has_param = true;
parameters_json.AddString(file.first.c_str(), encoded_handle, encoded_size);
free(encoded_handle);
}
if (has_param) {
request_json.Add("parameters", std::move(parameters_json));
}
triton::common::TritonJson::WriteBuffer buffer;
Error err = request_json.Write(&buffer);
if (!err.IsOk()) {
return err;
}
std::string response;
return Post(request_uri, buffer.Contents(), headers, query_params, &response);
}
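// A minimal sketch of an explicit model load with a config override. The
// model name and config string are illustrative, and the load/unload
// endpoints require the server to run with explicit model control:
//
//   std::string config = R"({"backend": "onnxruntime", "max_batch_size": 8})";
//   Error err = client->LoadModel("simple", Headers(), Parameters(), config);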
Error
InferenceServerHttpClient::UnloadModel(
const std::string& model_name, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(
url_ + "/v2/repository/models/" + model_name + "/unload");
std::string request; // empty request body
std::string response;
return Post(request_uri, request, headers, query_params, &response);
}
Error
InferenceServerHttpClient::ModelInferenceStatistics(
std::string* infer_stat, const std::string& model_name,
const std::string& model_version, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/models");
if (!model_name.empty()) {
request_uri += "/" + model_name;
}
if (!model_version.empty()) {
request_uri += "/versions/" + model_version;
}
request_uri += "/stats";
return Get(request_uri, headers, query_params, infer_stat);
}
Error
InferenceServerHttpClient::UpdateTraceSettings(
std::string* response, const std::string& model_name,
const std::map<std::string, std::vector<std::string>>& settings,
const Headers& headers, const Parameters& query_params)
{
std::string request_uri(url_ + "/v2");
if (!model_name.empty()) {
request_uri += "/models/" + model_name;
}
request_uri += "/trace/setting";
triton::common::TritonJson::Value request_json(
triton::common::TritonJson::ValueType::OBJECT);
{
for (const auto& pr : settings) {
if (pr.second.empty()) {
request_json.Add(pr.first.c_str(), triton::common::TritonJson::Value());
} else {
if (pr.first == "trace_level") {
triton::common::TritonJson::Value level_json(
triton::common::TritonJson::ValueType::ARRAY);
for (const auto& v : pr.second) {
level_json.AppendStringRef(v.c_str());
}
request_json.Add(pr.first.c_str(), std::move(level_json));
} else {
request_json.AddStringRef(pr.first.c_str(), pr.second[0].c_str());
}
}
}
}
triton::common::TritonJson::WriteBuffer buffer;
Error err = request_json.Write(&buffer);
if (!err.IsOk()) {
return err;
}
return Post(request_uri, buffer.Contents(), headers, query_params, response);
}
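// A minimal sketch of updating trace settings; the keys and values are
// illustrative and an empty value vector is serialized as an empty JSON
// value by the code above:
//
//   std::map<std::string, std::vector<std::string>> settings{
//       {"trace_level", {"TIMESTAMPS"}},
//       {"trace_rate", {"1"}},
//       {"trace_file", {}}};
//   std::string response;
//   client->UpdateTraceSettings(&response, "" /* global settings */, settings);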
Error
InferenceServerHttpClient::GetTraceSettings(
std::string* settings, const std::string& model_name,
const Headers& headers, const Parameters& query_params)
{
std::string request_uri(url_ + "/v2");
if (!model_name.empty()) {
request_uri += "/models/" + model_name;
}
request_uri += "/trace/setting";
return Get(request_uri, headers, query_params, settings);
}
Error
InferenceServerHttpClient::SystemSharedMemoryStatus(
std::string* status, const std::string& name, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/systemsharedmemory");
if (!name.empty()) {
request_uri = request_uri + "/region/" + name;
}
request_uri = request_uri + "/status";
return Get(request_uri, headers, query_params, status);
}
Error
InferenceServerHttpClient::RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size,
const size_t offset, const Headers& headers, const Parameters& query_params)
{
std::string request_uri(
url_ + "/v2/systemsharedmemory/region/" + name + "/register");
triton::common::TritonJson::Value request_json(
triton::common::TritonJson::ValueType::OBJECT);
{
request_json.AddStringRef("key", key.c_str(), key.size());
request_json.AddUInt("offset", offset);
request_json.AddUInt("byte_size", byte_size);
}
triton::common::TritonJson::WriteBuffer buffer;
Error err = request_json.Write(&buffer);
if (!err.IsOk()) {
return err;
}
std::string response;
return Post(request_uri, buffer.Contents(), headers, query_params, &response);
}
Error
InferenceServerHttpClient::UnregisterSystemSharedMemory(
const std::string& region_name, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/systemsharedmemory");
if (!region_name.empty()) {
request_uri = request_uri + "/region/" + region_name;
}
request_uri = request_uri + "/unregister";
std::string request; // empty request body
std::string response;
return Post(request_uri, request, headers, query_params, &response);
}
Error
InferenceServerHttpClient::CudaSharedMemoryStatus(
std::string* status, const std::string& region_name, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/cudasharedmemory");
if (!region_name.empty()) {
request_uri = request_uri + "/region/" + region_name;
}
request_uri = request_uri + "/status";
return Get(request_uri, headers, query_params, status);
}
Error
InferenceServerHttpClient::RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& raw_handle,
const size_t device_id, const size_t byte_size, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(
url_ + "/v2/cudasharedmemory/region/" + name + "/register");
triton::common::TritonJson::Value request_json(
triton::common::TritonJson::ValueType::OBJECT);
{
triton::common::TritonJson::Value raw_handle_json(
request_json, triton::common::TritonJson::ValueType::OBJECT);
{
// Must free encoded_handle after use to prevent memory leak
char* encoded_handle = nullptr;
int encoded_size;
Base64Encode(
(char*)((void*)&raw_handle), sizeof(cudaIpcMemHandle_t),
&encoded_handle, &encoded_size);
if (encoded_handle == nullptr) {
return Error("Failed to base64 encode the cudaIpcMemHandle_t");
}
raw_handle_json.AddString("b64", encoded_handle, encoded_size);
free(encoded_handle);
}
request_json.Add("raw_handle", std::move(raw_handle_json));
request_json.AddUInt("device_id", device_id);
request_json.AddUInt("byte_size", byte_size);
}
triton::common::TritonJson::WriteBuffer buffer;
Error err = request_json.Write(&buffer);
if (!err.IsOk()) {
return err;
}
std::string response;
return Post(request_uri, buffer.Contents(), headers, query_params, &response);
}
Error
InferenceServerHttpClient::UnregisterCudaSharedMemory(
const std::string& name, const Headers& headers,
const Parameters& query_params)
{
std::string request_uri(url_ + "/v2/cudasharedmemory");
if (!name.empty()) {
request_uri = request_uri + "/region/" + name;
}
request_uri = request_uri + "/unregister";
std::string request; // empty request body
std::string response;
return Post(request_uri, request, headers, query_params, &response);
}
Error
InferenceServerHttpClient::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm)
{
Error err;
std::string request_uri(url_ + "/v2/models/" + options.model_name_);
if (!options.model_version_.empty()) {
request_uri = request_uri + "/versions/" + options.model_version_;
}
request_uri = request_uri + "/infer";
std::shared_ptr<HttpInferRequest> sync_request(
new HttpInferRequest(nullptr /* callback */, verbose_));
sync_request->Timer().Reset();
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_START);
if (!CurlGlobal::Get().Status().IsOk()) {
return CurlGlobal::Get().Status();
}
err = PreRunProcessing(
easy_handle_, request_uri, options, inputs, outputs, headers,
query_params, request_compression_algorithm,
response_compression_algorithm, sync_request);
if (!err.IsOk()) {
return err;
}
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
// Set SEND_END when content length is 0 (because
// CURLOPT_READFUNCTION will not be called). In that case, we can't
// measure SEND_END properly (send ends after sending request
// header).
if (sync_request->total_input_byte_size_ == 0) {
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}
// During this call SEND_END (except in above case), RECV_START, and
// RECV_END will be set.
auto curl_status = curl_easy_perform(easy_handle_);
if (curl_status == CURLE_OPERATION_TIMEDOUT) {
return Error(
"HTTP client failed (Deadline Exceeded): " +
std::string(curl_easy_strerror(curl_status)));
} else if (curl_status != CURLE_OK) {
return Error(
"HTTP client failed: " + std::string(curl_easy_strerror(curl_status)));
} else { // Success
curl_easy_getinfo(
easy_handle_, CURLINFO_RESPONSE_CODE, &sync_request->http_code_);
}
InferResultHttp::Create(result, sync_request);
sync_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_END);
err = UpdateInferStat(sync_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err << std::endl;
}
err = (*result)->RequestStatus();
return err;
}
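// A minimal synchronous usage sketch for Infer(). Input construction follows
// the companion common.h API; the model name, tensor name, shape and data are
// illustrative and default arguments from the header are assumed:
//
//   std::vector<int32_t> data(16, 1);
//   InferInput* input = nullptr;
//   InferInput::Create(&input, "INPUT0", {1, 16}, "INT32");
//   input->AppendRaw(
//       reinterpret_cast<const uint8_t*>(data.data()),
//       data.size() * sizeof(int32_t));
//   InferOptions options("simple");
//   InferResult* result = nullptr;
//   Error err = client->Infer(&result, options, {input});
//   if (err.IsOk()) {
//     err = result->RequestStatus();
//   }
//   delete result;
//   delete input;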
Error
InferenceServerHttpClient::AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm)
{
if (callback == nullptr) {
return Error(
"Callback function must be provided along with AsyncInfer() call.");
}
std::shared_ptr<HttpInferRequest> async_request;
if (!multi_handle_) {
return Error("failed to start HTTP asynchronous client");
} else if (!worker_.joinable()) {
worker_ = std::thread(&InferenceServerHttpClient::AsyncTransfer, this);
}
std::string request_uri(url_ + "/v2/models/" + options.model_name_);
if (!options.model_version_.empty()) {
request_uri = request_uri + "/versions/" + options.model_version_;
}
request_uri = request_uri + "/infer";
HttpInferRequest* raw_async_request =
new HttpInferRequest(std::move(callback), verbose_);
async_request.reset(raw_async_request);
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::REQUEST_START);
CURL* multi_easy_handle = curl_easy_init();
Error err = PreRunProcessing(
reinterpret_cast<void*>(multi_easy_handle), request_uri, options, inputs,
outputs, headers, query_params, request_compression_algorithm,
response_compression_algorithm, async_request);
if (!err.IsOk()) {
curl_easy_cleanup(multi_easy_handle);
return err;
}
{
std::lock_guard<std::mutex> lock(mutex_);
auto insert_result = ongoing_async_requests_.emplace(std::make_pair(
reinterpret_cast<uintptr_t>(multi_easy_handle), async_request));
if (!insert_result.second) {
curl_easy_cleanup(multi_easy_handle);
return Error("Failed to insert new asynchronous request context.");
}
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_START);
if (async_request->total_input_byte_size_ == 0) {
// Set SEND_END here because CURLOPT_READFUNCTION will not be called if
// content length is 0. In that case, we can't measure SEND_END properly
// (send ends after sending request header).
async_request->Timer().CaptureTimestamp(RequestTimers::Kind::SEND_END);
}
curl_multi_add_handle(multi_handle_, multi_easy_handle);
}
cv_.notify_all();
return Error::Success;
}
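// A minimal asynchronous usage sketch. The callback runs on the worker thread
// started above, so this example hands the result back through a
// std::promise (<future> required); 'options' and 'input' are assumed to be
// built as in the Infer() sketch:
//
//   std::promise<InferResult*> completed;
//   auto future = completed.get_future();
//   client->AsyncInfer(
//       [&completed](InferResult* result) { completed.set_value(result); },
//       options, {input});
//   InferResult* result = future.get();
//   Error err = result->RequestStatus();
//   delete result;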
Error
InferenceServerHttpClient::InferMulti(
std::vector<InferResult*>* results,
const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm)
{
Error err;
// Sanity check
if ((inputs.size() != options.size()) && (options.size() != 1)) {
return Error(
"'options' must either contain 1 element or match size of 'inputs'");
}
if ((inputs.size() != outputs.size()) &&
((outputs.size() != 1) && (outputs.size() != 0))) {
return Error(
"'outputs' must either contain 0/1 element or match size of 'inputs'");
}
int64_t max_option_idx = options.size() - 1;
// value of '-1' means no output is specified
int64_t max_output_idx = outputs.size() - 1;
static std::vector<const InferRequestedOutput*> empty_outputs{};
for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) {
const auto& request_options = options[std::min(max_option_idx, i)];
const auto& request_output = (max_output_idx == -1)
? empty_outputs
: outputs[std::min(max_output_idx, i)];
results->emplace_back();
err = Infer(
&results->back(), request_options, inputs[i], request_output, headers,
query_params, request_compression_algorithm,
response_compression_algorithm);
if (!err.IsOk()) {
return err;
}
}
return Error::Success;
}
Error
InferenceServerHttpClient::AsyncInferMulti(
OnMultiCompleteFn callback, const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm)
{
// Sanity check
if ((inputs.size() != options.size()) && (options.size() != 1)) {
return Error(
"'options' must either contain 1 element or match size of 'inputs'");
}
if ((inputs.size() != outputs.size()) &&
((outputs.size() != 1) && (outputs.size() != 0))) {
return Error(
"'outputs' must either contain 0/1 element or match size of 'inputs'");
}
if (callback == nullptr) {
    return Error(
        "Callback function must be provided along with AsyncInferMulti() call.");
}
int64_t max_option_idx = options.size() - 1;
// value of '-1' means no output is specified
int64_t max_output_idx = outputs.size() - 1;
static std::vector<const InferRequestedOutput*> empty_outputs{};
std::shared_ptr<std::atomic<size_t>> response_counter(
new std::atomic<size_t>(inputs.size()));
std::shared_ptr<std::vector<InferResult*>> responses(
new std::vector<InferResult*>(inputs.size()));
for (int64_t i = 0; i < (int64_t)inputs.size(); ++i) {
const auto& request_options = options[std::min(max_option_idx, i)];
const auto& request_output = (max_output_idx == -1)
? empty_outputs
: outputs[std::min(max_output_idx, i)];
OnCompleteFn cb = [response_counter, responses, i,
callback](InferResult* result) {
(*responses)[i] = result;
// last response
if (response_counter->fetch_sub(1) == 1) {
std::vector<InferResult*> results;
results.swap(*responses);
callback(results);
}
};
auto err = AsyncInfer(
cb, request_options, inputs[i], request_output, headers, query_params,
request_compression_algorithm, response_compression_algorithm);
if (!err.IsOk()) {
// Create response with error as other requests may be sent and their
// responses may not be accessed outside the callback.
InferResult* err_res;
err = InferResultHttp::Create(&err_res, err);
if (!err.IsOk()) {
std::cerr << "Failed to create result for error: " << err.Message()
<< std::endl;
}
cb(err_res);
continue;
}
}
return Error::Success;
}
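// A minimal sketch for AsyncInferMulti(). The callback receives one
// InferResult* per request once all requests have completed, and ownership
// of the results passes to the callback; the argument vectors are assumed to
// be built as for InferMulti():
//
//   client->AsyncInferMulti(
//       [](std::vector<InferResult*> results) {
//         for (InferResult* r : results) {
//           // Inspect r->RequestStatus() and outputs, then release.
//           delete r;
//         }
//       },
//       options_vec, inputs_vec, outputs_vec);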
size_t
InferenceServerHttpClient::InferRequestProvider(
void* contents, size_t size, size_t nmemb, void* userp)
{
HttpInferRequest* request = reinterpret_cast<HttpInferRequest*>(userp);
size_t input_bytes = 0;
Error err = request->GetNextInput(
reinterpret_cast<uint8_t*>(contents), size * nmemb, &input_bytes);
if (!err.IsOk()) {
std::cerr << "RequestProvider: " << err << std::endl;
return CURL_READFUNC_ABORT;
}
return input_bytes;
}
size_t
InferenceServerHttpClient::InferResponseHeaderHandler(
void* contents, size_t size, size_t nmemb, void* userp)
{
HttpInferRequest* request = reinterpret_cast<HttpInferRequest*>(userp);
char* buf = reinterpret_cast<char*>(contents);
size_t byte_size = size * nmemb;
size_t idx = strlen(kInferHeaderContentLengthHTTPHeader);
size_t length_idx = strlen(kContentLengthHTTPHeader);
if ((idx < byte_size) &&
!strncasecmp(buf, kInferHeaderContentLengthHTTPHeader, idx)) {
while ((idx < byte_size) && (buf[idx] != ':')) {
++idx;
}
if (idx < byte_size) {
std::string hdr(buf + idx + 1, byte_size - idx - 1);
request->response_json_size_ = std::stoi(hdr);
}
} else if (
(length_idx < byte_size) &&
!strncasecmp(buf, kContentLengthHTTPHeader, length_idx)) {
while ((length_idx < byte_size) && (buf[length_idx] != ':')) {
++length_idx;
}
if (length_idx < byte_size) {
std::string hdr(buf + length_idx + 1, byte_size - length_idx - 1);
request->infer_response_buffer_->reserve(std::stoi(hdr));
}
}
return byte_size;
}
size_t
InferenceServerHttpClient::InferResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp)
{
HttpInferRequest* request = reinterpret_cast<HttpInferRequest*>(userp);
if (request->Timer().Timestamp(RequestTimers::Kind::RECV_START) == 0) {
request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_START);
}
char* buf = reinterpret_cast<char*>(contents);
size_t result_bytes = size * nmemb;
request->infer_response_buffer_->append(buf, result_bytes);
// InferResponseHandler may be called multiple times so we overwrite
// RECV_END so that we always have the time of the last.
request->Timer().CaptureTimestamp(RequestTimers::Kind::RECV_END);
return result_bytes;
}
Error
InferenceServerHttpClient::PreRunProcessing(
void* vcurl, std::string& request_uri, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm,
std::shared_ptr<HttpInferRequest>& http_request)
{
CURL* curl = reinterpret_cast<CURL*>(vcurl);
// Prepare the request object to provide the data for inference.
Error err = http_request->InitializeRequest(options, inputs, outputs);
if (!err.IsOk()) {
return err;
}
// Add the buffers holding input tensor data
bool all_inputs_are_json{true};
for (const auto this_input : inputs) {
if (this_input->BinaryData()) {
all_inputs_are_json = false;
}
if (!this_input->IsSharedMemory() && this_input->BinaryData()) {
this_input->PrepareForRequest();
bool end_of_input = false;
while (!end_of_input) {
const uint8_t* buf;
size_t buf_size;
this_input->GetNext(&buf, &buf_size, &end_of_input);
if (buf != nullptr) {
http_request->AddInput(const_cast<uint8_t*>(buf), buf_size);
}
}
}
}
// Compress data if requested
switch (request_compression_algorithm) {
case CompressionType::NONE:
break;
case CompressionType::DEFLATE:
case CompressionType::GZIP:
#ifdef TRITON_ENABLE_ZLIB
http_request->CompressInput(request_compression_algorithm);
break;
#else
return Error(
"Compression type needs to be CompressionType::NONE since ZLIB is "
"not included in client build");
#endif
}
// Prepare curl
if (!query_params.empty()) {
request_uri = request_uri + "?" + GetQueryString(query_params);
}
curl_easy_setopt(curl, CURLOPT_URL, request_uri.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0");
curl_easy_setopt(curl, CURLOPT_POST, 1L);
curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1L);
if (options.client_timeout_ != 0) {
uint64_t timeout_ms = (options.client_timeout_ / 1000);
curl_easy_setopt(curl, CURLOPT_TIMEOUT_MS, timeout_ms);
}
if (verbose_) {
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
}
const long buffer_byte_size = 16 * 1024 * 1024;
curl_easy_setopt(curl, CURLOPT_UPLOAD_BUFFERSIZE, buffer_byte_size);
curl_easy_setopt(curl, CURLOPT_BUFFERSIZE, buffer_byte_size);
// request data provided by InferRequestProvider()
curl_easy_setopt(curl, CURLOPT_READFUNCTION, InferRequestProvider);
curl_easy_setopt(curl, CURLOPT_READDATA, http_request.get());
// response headers handled by InferResponseHeaderHandler()
curl_easy_setopt(curl, CURLOPT_HEADERFUNCTION, InferResponseHeaderHandler);
curl_easy_setopt(curl, CURLOPT_HEADERDATA, http_request.get());
// response data handled by InferResponseHandler()
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, InferResponseHandler);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, http_request.get());
const curl_off_t post_byte_size = http_request->total_input_byte_size_;
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE_LARGE, post_byte_size);
err = SetSSLCurlOptions(&curl, ssl_options_);
if (!err.IsOk()) {
return err;
}
struct curl_slist* list = nullptr;
std::string infer_hdr{
std::string(kInferHeaderContentLengthHTTPHeader) + ": " +
std::to_string(http_request->request_json_.Size())};
list = curl_slist_append(list, infer_hdr.c_str());
list = curl_slist_append(list, "Expect:");
if (all_inputs_are_json) {
list = curl_slist_append(list, "Content-Type: application/json");
} else {
list = curl_slist_append(list, "Content-Type: application/octet-stream");
}
for (const auto& pr : headers) {
std::string hdr = pr.first + ": " + pr.second;
list = curl_slist_append(list, hdr.c_str());
}
  // Set the Content-Encoding header to match the requested compression
switch (request_compression_algorithm) {
case CompressionType::NONE:
break;
case CompressionType::DEFLATE:
list = curl_slist_append(list, "Content-Encoding: deflate");
break;
case CompressionType::GZIP:
list = curl_slist_append(list, "Content-Encoding: gzip");
break;
}
switch (response_compression_algorithm) {
case CompressionType::NONE:
break;
case CompressionType::DEFLATE:
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "deflate");
break;
case CompressionType::GZIP:
curl_easy_setopt(curl, CURLOPT_ACCEPT_ENCODING, "gzip");
break;
}
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, list);
// The list will be freed when the request is destructed
http_request->header_list_ = list;
if (verbose_) {
std::cout << "inference request: " << http_request->request_json_.Contents()
<< std::endl;
}
return Error::Success;
}
void
InferenceServerHttpClient::AsyncTransfer()
{
int place_holder = 0;
CURLMsg* msg = nullptr;
do {
std::vector<std::shared_ptr<HttpInferRequest>> request_list;
// sleep if no work is available
std::unique_lock<std::mutex> lock(mutex_);
cv_.wait(lock, [this] {
if (this->exiting_) {
return true;
}
// wake up if an async request has been generated
return !this->ongoing_async_requests_.empty();
});
CURLMcode mc = curl_multi_perform(multi_handle_, &place_holder);
int numfds;
if (mc == CURLM_OK) {
// Wait for activity. If there are no descriptors in the multi_handle_
// then curl_multi_wait will return immediately
mc = curl_multi_wait(multi_handle_, NULL, 0, INT_MAX, &numfds);
if (mc == CURLM_OK) {
while ((msg = curl_multi_info_read(multi_handle_, &place_holder))) {
uintptr_t identifier = reinterpret_cast<uintptr_t>(msg->easy_handle);
auto itr = ongoing_async_requests_.find(identifier);
// This shouldn't happen
if (itr == ongoing_async_requests_.end()) {
std::cerr
<< "Unexpected error: received completed request that is not "
"in the list of asynchronous requests"
<< std::endl;
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
continue;
}
long http_code = 400;
if (msg->data.result == CURLE_OK) {
curl_easy_getinfo(
msg->easy_handle, CURLINFO_RESPONSE_CODE, &http_code);
} else if (msg->data.result == CURLE_OPERATION_TIMEDOUT) {
http_code = 499;
}
request_list.emplace_back(itr->second);
ongoing_async_requests_.erase(itr);
curl_multi_remove_handle(multi_handle_, msg->easy_handle);
curl_easy_cleanup(msg->easy_handle);
std::shared_ptr<HttpInferRequest> async_request = request_list.back();
async_request->http_code_ = http_code;
if (msg->msg != CURLMSG_DONE) {
// Something wrong happened.
std::cerr << "Unexpected error: received CURLMsg=" << msg->msg
<< std::endl;
} else {
async_request->Timer().CaptureTimestamp(
RequestTimers::Kind::REQUEST_END);
Error err = UpdateInferStat(async_request->Timer());
if (!err.IsOk()) {
std::cerr << "Failed to update context stat: " << err
<< std::endl;
}
}
}
} else {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
}
} else {
std::cerr << "Unexpected error: curl_multi failed. Code:" << mc
<< std::endl;
}
lock.unlock();
for (auto& this_request : request_list) {
InferResult* result;
InferResultHttp::Create(&result, this_request);
this_request->callback_(result);
}
} while (!exiting_);
}
size_t
InferenceServerHttpClient::ResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp)
{
std::string* response_string = reinterpret_cast<std::string*>(userp);
uint8_t* buf = reinterpret_cast<uint8_t*>(contents);
size_t result_bytes = size * nmemb;
std::copy(buf, buf + result_bytes, std::back_inserter(*response_string));
return result_bytes;
}
namespace {
Error
ParseErrorJson(const std::string& json_str)
{
triton::common::TritonJson::Value json;
Error err = json.Parse(json_str.c_str(), json_str.size());
if (!err.IsOk()) {
return err;
}
const char* errstr;
size_t errlen;
err = json.MemberAsString("error", &errstr, &errlen);
if (!err.IsOk()) {
return err;
}
  return Error(std::string(errstr, errlen));
}
} // namespace
Error
InferenceServerHttpClient::Get(
std::string& request_uri, const Headers& headers,
const Parameters& query_params, std::string* response, long* http_code)
{
if (!query_params.empty()) {
request_uri = request_uri + "?" + GetQueryString(query_params);
}
if (!CurlGlobal::Get().Status().IsOk()) {
return CurlGlobal::Get().Status();
}
CURL* curl = curl_easy_init();
if (!curl) {
return Error("failed to initialize HTTP client");
}
curl_easy_setopt(curl, CURLOPT_URL, request_uri.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0");
curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1L);
if (verbose_) {
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
}
// Response data handled by ResponseHandler()
response->clear();
response->reserve(1024);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, ResponseHandler);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, response);
Error err = SetSSLCurlOptions(&curl, ssl_options_);
if (!err.IsOk()) {
return err;
}
// Add user provided headers...
struct curl_slist* header_list = nullptr;
for (const auto& pr : headers) {
std::string hdr = pr.first + ": " + pr.second;
header_list = curl_slist_append(header_list, hdr.c_str());
}
if (header_list != nullptr) {
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, header_list);
}
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
curl_slist_free_all(header_list);
curl_easy_cleanup(curl);
return Error("HTTP client failed: " + std::string(curl_easy_strerror(res)));
}
long lhttp_code;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &lhttp_code);
curl_slist_free_all(header_list);
curl_easy_cleanup(curl);
if (verbose_) {
std::cout << *response << std::endl;
}
// If http code was requested for return, then just return it,
// otherwise flag an error if the http code is not 200.
if (http_code != nullptr) {
*http_code = lhttp_code;
} else if (lhttp_code != 200) {
return ParseErrorJson(*response);
}
return Error::Success;
}
Error
InferenceServerHttpClient::Post(
std::string& request_uri, const std::string& request,
const Headers& headers, const Parameters& query_params,
std::string* response)
{
if (!query_params.empty()) {
request_uri = request_uri + "?" + GetQueryString(query_params);
}
if (!CurlGlobal::Get().Status().IsOk()) {
return CurlGlobal::Get().Status();
}
CURL* curl = curl_easy_init();
if (!curl) {
return Error("failed to initialize HTTP client");
}
curl_easy_setopt(curl, CURLOPT_URL, request_uri.c_str());
curl_easy_setopt(curl, CURLOPT_USERAGENT, "libcurl-agent/1.0");
curl_easy_setopt(curl, CURLOPT_TCP_NODELAY, 1L);
curl_easy_setopt(curl, CURLOPT_POSTFIELDSIZE, request.size());
curl_easy_setopt(curl, CURLOPT_POSTFIELDS, request.c_str());
if (verbose_) {
curl_easy_setopt(curl, CURLOPT_VERBOSE, 1L);
}
// Response data handled by ResponseHandler()
response->clear();
response->reserve(1024);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, ResponseHandler);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, response);
Error err = SetSSLCurlOptions(&curl, ssl_options_);
if (!err.IsOk()) {
return err;
}
// Add user provided headers...
struct curl_slist* header_list = nullptr;
for (const auto& pr : headers) {
std::string hdr = pr.first + ": " + pr.second;
header_list = curl_slist_append(header_list, hdr.c_str());
}
if (header_list != nullptr) {
curl_easy_setopt(curl, CURLOPT_HTTPHEADER, header_list);
}
CURLcode res = curl_easy_perform(curl);
if (res != CURLE_OK) {
curl_slist_free_all(header_list);
curl_easy_cleanup(curl);
return Error("HTTP client failed: " + std::string(curl_easy_strerror(res)));
}
long http_code;
curl_easy_getinfo(curl, CURLINFO_RESPONSE_CODE, &http_code);
curl_slist_free_all(header_list);
curl_easy_cleanup(curl);
if (verbose_) {
std::cout << *response << std::endl;
}
if (http_code != 200) {
return ParseErrorJson(*response);
}
return Error::Success;
}
//==============================================================================
}} // namespace triton::client
// Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
/// \file
#include <map>
#include <memory>
#include "common.h"
#include "ipc.h"
namespace triton { namespace client {
class HttpInferRequest;
/// The key-value map type to be included in the request
/// as custom headers.
typedef std::map<std::string, std::string> Headers;
/// The key-value map type to be included as URL parameters.
typedef std::map<std::string, std::string> Parameters;
// The options for authorizing and authenticating SSL/TLS connections.
struct HttpSslOptions {
enum CERTTYPE { CERT_PEM = 0, CERT_DER = 1 };
enum KEYTYPE {
KEY_PEM = 0,
KEY_DER = 1
// TODO: Support loading private key from crypto engine
// KEY_ENG = 2
};
explicit HttpSslOptions()
: verify_peer(1), verify_host(2), cert_type(CERTTYPE::CERT_PEM),
key_type(KEYTYPE::KEY_PEM)
{
}
// This option determines whether curl verifies the authenticity of the peer's
// certificate. A value of 1 means curl verifies; 0 (zero) means it does not.
// Default value is 1. See here for more details:
// https://curl.se/libcurl/c/CURLOPT_SSL_VERIFYPEER.html
long verify_peer;
// This option determines whether libcurl verifies that the server cert is for
  // the server it is known as. The default value for this option is 2, which
  // means that the certificate must indicate that the server is the server
  // to which you meant to connect, or the connection fails. See here for
  // more details:
// https://curl.se/libcurl/c/CURLOPT_SSL_VERIFYHOST.html
long verify_host;
  // File holding one or more certificates to verify the peer with. If not
  // specified, the client will look in the system path where the cacert bundle
  // is assumed to be stored, as established at build time. See here for more
// information: https://curl.se/libcurl/c/CURLOPT_CAINFO.html
std::string ca_info;
// The format of client certificate. By default it is CERT_PEM. See here for
// more details: https://curl.se/libcurl/c/CURLOPT_SSLCERTTYPE.html
CERTTYPE cert_type;
// The file name of your client certificate. See here for more details:
// https://curl.se/libcurl/c/CURLOPT_SSLCERT.html
std::string cert;
// The format of the private key. By default it is KEY_PEM. See here for more
// details: https://curl.se/libcurl/c/CURLOPT_SSLKEYTYPE.html.
KEYTYPE key_type;
// The private key. See here for more details:
// https://curl.se/libcurl/c/CURLOPT_SSLKEY.html.
std::string key;
};
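// A minimal configuration sketch (illustrative only; the certificate and key
// paths below are placeholders, not part of this library):
//
//   HttpSslOptions ssl_options;
//   ssl_options.ca_info = "/path/to/ca-bundle.crt";
//   ssl_options.cert_type = HttpSslOptions::CERTTYPE::CERT_PEM;
//   ssl_options.cert = "/path/to/client-cert.pem";
//   ssl_options.key_type = HttpSslOptions::KEYTYPE::KEY_PEM;
//   ssl_options.key = "/path/to/client-key.pem";
//   // Pass 'ssl_options' to InferenceServerHttpClient::Create() together
//   // with an https:// server_url; otherwise these options are ignored.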
//==============================================================================
/// An InferenceServerHttpClient object is used to perform any kind of
/// communication with the InferenceServer using the HTTP protocol. None
/// of the methods of InferenceServerHttpClient are thread safe. The
/// class is intended to be used by a single thread; simultaneously
/// calling different methods from different threads is not supported
/// and will cause undefined behavior.
///
/// \code
/// std::unique_ptr<InferenceServerHttpClient> client;
/// InferenceServerHttpClient::Create(&client, "localhost:8000");
/// bool live;
/// client->IsServerLive(&live);
/// ...
/// ...
/// \endcode
///
class InferenceServerHttpClient : public InferenceServerClient {
public:
enum class CompressionType { NONE, DEFLATE, GZIP };
~InferenceServerHttpClient();
/// Generate a request body for inference using the supplied 'inputs' and
/// requesting the outputs specified by 'outputs'.
/// \param request_body Returns the generated inference request body
/// \param header_length Returns the length of the inference header.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs The vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the
/// model config will be returned as default settings.
/// \return Error object indicating success or failure of the
/// request.
static Error GenerateRequestBody(
std::vector<char>* request_body, size_t* header_length,
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>());
/// Generate a InferResult object from the given 'response_body'.
/// \param result Returns the generated InferResult object.
/// \param response_body The inference response from the server
/// \param header_length The length of the inference header if the header
/// does not occupy the whole response body. 0 indicates that
/// the whole response body is the inference response header.
/// \return Error object indicating success or failure of the
/// request.
static Error ParseResponseBody(
InferResult** result, const std::vector<char>& response_body,
const size_t header_length = 0);
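  /// An illustrative usage sketch of the two helpers above ('options',
  /// 'inputs', 'outputs' and the received 'response_body' bytes are assumed
  /// to be prepared by the caller's own transport code):
  /// \code
  ///   std::vector<char> request_body;
  ///   size_t header_length = 0;
  ///   InferenceServerHttpClient::GenerateRequestBody(
  ///       &request_body, &header_length, options, inputs, outputs);
  ///   // ... send 'request_body' over a custom transport, collect the
  ///   // response bytes and the Inference-Header-Content-Length value ...
  ///   InferResult* result;
  ///   InferenceServerHttpClient::ParseResponseBody(
  ///       &result, response_body, response_header_length);
  /// \endcode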
/// Create a client that can be used to communicate with the server.
/// \param client Returns a new InferenceServerHttpClient object.
/// \param server_url The inference server name, port, optional
/// scheme and optional base path in the following format:
/// <scheme://>host:port/<base-path>.
/// \param verbose If true generate verbose output when contacting
/// the inference server.
/// \param ssl_options Specifies the settings for configuring
/// SSL encryption and authorization. Providing these options
  /// does not ensure that SSL/TLS will be used for communication.
/// The use of SSL/TLS depends entirely on the server endpoint.
/// These options will be ignored if the server_url does not
  /// use the `https://` scheme.
/// \return Error object indicating success or failure.
static Error Create(
std::unique_ptr<InferenceServerHttpClient>* client,
const std::string& server_url, bool verbose = false,
const HttpSslOptions& ssl_options = HttpSslOptions());
/// Contact the inference server and get its liveness.
/// \param live Returns whether the server is live or not.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \return Error object indicating success or failure of the request.
Error IsServerLive(
bool* live, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get its readiness.
/// \param ready Returns whether the server is ready or not.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error IsServerReady(
bool* ready, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the readiness of specified model.
/// \param ready Returns whether the specified model is ready or not.
/// \param model_name The name of the model to check for readiness.
/// \param model_version The version of the model to check for readiness.
  /// The default value is an empty string, which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error IsModelReady(
bool* ready, const std::string& model_name,
const std::string& model_version = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get its metadata.
/// \param server_metadata Returns JSON representation of the
/// metadata as a string.
/// \param headers Optional map specifying additional HTTP headers to
/// include in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error ServerMetadata(
std::string* server_metadata, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the metadata of specified model.
/// \param model_metadata Returns JSON representation of model
/// metadata as a string.
/// \param model_name The name of the model to get metadata.
/// \param model_version The version of the model to get metadata.
  /// The default value is an empty string, which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error ModelMetadata(
std::string* model_metadata, const std::string& model_name,
const std::string& model_version = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the configuration of specified model.
/// \param model_config Returns JSON representation of model
/// configuration as a string.
/// \param model_name The name of the model to get configuration.
/// \param model_version The version of the model to get configuration.
  /// The default value is an empty string, which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error ModelConfig(
std::string* model_config, const std::string& model_name,
const std::string& model_version = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the index of model repository
/// contents.
/// \param repository_index Returns JSON representation of the
/// repository index as a string.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error ModelRepositoryIndex(
std::string* repository_index, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Request the inference server to load or reload specified model.
/// \param model_name The name of the model to be loaded or reloaded.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \param config Optional JSON representation of a model config provided for
  /// the load request. If provided, this config will be used for
/// loading the model.
/// \param files Optional map specifying file path (with "file:"
/// prefix) in the override model directory to the file content.
/// The files will form the model directory that the model
/// will be loaded from. If specified, 'config' must be provided to be
/// the model configuration of the override model directory.
/// \return Error object indicating success or failure of the request.
Error LoadModel(
const std::string& model_name, const Headers& headers = Headers(),
const Parameters& query_params = Parameters(),
const std::string& config = std::string(),
const std::map<std::string, std::vector<char>>& files = {});
/// Request the inference server to unload specified model.
/// \param model_name The name of the model to be unloaded.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error UnloadModel(
const std::string& model_name, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the inference statistics for the
/// specified model name and version.
/// \param infer_stat Returns the JSON representation of the
/// inference statistics as a string.
/// \param model_name The name of the model to get inference statistics. The
/// default value is an empty string which means statistics of all models will
/// be returned in the response.
/// \param model_version The version of the model to get inference statistics.
  /// The default value is an empty string, which means the server will
/// choose a version based on the model and internal policy.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error ModelInferenceStatistics(
std::string* infer_stat, const std::string& model_name = "",
const std::string& model_version = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Update the trace settings for the specified model name, or global trace
/// settings if model name is not given.
/// \param response Returns the JSON representation of the updated trace
/// settings as a string.
/// \param model_name The name of the model to update trace settings. The
/// default value is an empty string which means the global trace settings
/// will be updated.
/// \param settings The new trace setting values. Only the settings listed
/// will be updated. If a trace setting is listed in the map with an empty
/// string, that setting will be cleared.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error UpdateTraceSettings(
std::string* response, const std::string& model_name = "",
const std::map<std::string, std::vector<std::string>>& settings =
std::map<std::string, std::vector<std::string>>(),
const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
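  /// An illustrative sketch (the setting names follow the server's trace
  /// protocol extension and should be checked against the server docs; the
  /// model name "simple" is a placeholder):
  /// \code
  ///   std::map<std::string, std::vector<std::string>> trace_settings{
  ///       {"trace_level", {"TIMESTAMPS"}}, {"trace_rate", {"100"}}};
  ///   std::string response;
  ///   client->UpdateTraceSettings(&response, "simple", trace_settings);
  /// \endcode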
/// Get the trace settings for the specified model name, or global trace
/// settings if model name is not given.
/// \param settings Returns the JSON representation of the trace
/// settings as a string.
/// \param model_name The name of the model to get trace settings. The
/// default value is an empty string which means the global trace settings
/// will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error GetTraceSettings(
std::string* settings, const std::string& model_name = "",
const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the status for requested system
/// shared memory.
/// \param status Returns the JSON representation of the system
/// shared memory status as a string.
/// \param region_name The name of the region to query status. The default
/// value is an empty string, which means that the status of all active system
/// shared memory will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error SystemSharedMemoryStatus(
std::string* status, const std::string& region_name = "",
const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Request the server to register a system shared memory with the provided
/// details.
/// \param name The name of the region to register.
/// \param key The key of the underlying memory object that contains the
/// system shared memory region.
/// \param byte_size The size of the system shared memory region, in bytes.
/// \param offset Offset, in bytes, within the underlying memory object to
/// the start of the system shared memory region. The default value is zero.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request
Error RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size,
const size_t offset = 0, const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
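  /// An illustrative sketch (the region name, key and byte size are
  /// placeholders; the key refers to a region created with shm_open(),
  /// e.g. via the CreateSharedMemoryRegion() helper in shm_utils.h):
  /// \code
  ///   client->RegisterSystemSharedMemory(
  ///       "input_data", "/input_shm_key", 64 /* byte_size */);
  /// \endcode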
/// Request the server to unregister a system shared memory with the
/// specified name.
/// \param name The name of the region to unregister. The default value is
/// empty string which means all the system shared memory regions will be
/// unregistered.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request
Error UnregisterSystemSharedMemory(
const std::string& name = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Contact the inference server and get the status for requested CUDA
/// shared memory.
/// \param status Returns the JSON representation of the CUDA shared
/// memory status as a string.
/// \param region_name The name of the region to query status. The default
/// value is an empty string, which means that the status of all active CUDA
/// shared memory will be returned.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request.
Error CudaSharedMemoryStatus(
std::string* status, const std::string& region_name = "",
const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Request the server to register a CUDA shared memory with the provided
/// details.
/// \param name The name of the region to register.
/// \param cuda_shm_handle The cudaIPC handle for the memory object.
/// \param device_id The GPU device ID on which the cudaIPC handle was
/// created.
/// \param byte_size The size of the CUDA shared memory region, in
/// bytes.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request
Error RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& cuda_shm_handle,
const size_t device_id, const size_t byte_size,
const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Request the server to unregister a CUDA shared memory with the
/// specified name.
/// \param name The name of the region to unregister. The default value is
/// empty string which means all the CUDA shared memory regions will be
/// unregistered.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \return Error object indicating success or failure of the request
Error UnregisterCudaSharedMemory(
const std::string& name = "", const Headers& headers = Headers(),
const Parameters& query_params = Parameters());
/// Run synchronous inference on server.
/// \param result Returns the result of inference.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs The vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the
/// model config will be returned as default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \param request_compression_algorithm Optional HTTP compression algorithm
/// to use for the request body on client side. Currently supports DEFLATE,
/// GZIP and NONE. By default, no compression is used.
/// \param response_compression_algorithm Optional HTTP compression algorithm
/// to request for the response body. Note that the response may not be
/// compressed if the server does not support the specified algorithm.
/// Currently supports DEFLATE, GZIP and NONE. By default, no compression
/// is used.
/// \return Error object indicating success or failure of the
/// request.
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
const Parameters& query_params = Parameters(),
const CompressionType request_compression_algorithm =
CompressionType::NONE,
const CompressionType response_compression_algorithm =
CompressionType::NONE);
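  /// An illustrative synchronous inference sketch (the model name "simple",
  /// input name "INPUT0", shape and datatype are assumptions for the example,
  /// not requirements of this API; error checking omitted):
  /// \code
  ///   std::vector<int32_t> data(16, 1);
  ///   InferInput* input;
  ///   InferInput::Create(&input, "INPUT0", {1, 16}, "INT32");
  ///   input->AppendRaw(
  ///       reinterpret_cast<uint8_t*>(data.data()),
  ///       data.size() * sizeof(int32_t));
  ///   InferOptions options("simple");
  ///   InferResult* result;
  ///   client->Infer(&result, options, {input});
  ///   // ... read outputs from 'result' ...
  ///   delete result;
  ///   delete input;
  /// \endcode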
/// Run asynchronous inference on server.
/// Once the request is completed, the InferResult pointer will be passed to
  /// the provided 'callback' function. Upon invocation of the callback
  /// function, the ownership of the InferResult object is transferred to the
  /// function caller. It is then the caller's choice to either retrieve the
  /// results inside the callback function or defer that to a different thread
  /// so that the client is unblocked. In order to prevent a memory leak, the
  /// user must ensure this object gets deleted.
  /// Note: InferInput::AppendRaw() or InferInput::SetSharedMemory() calls do
  /// not copy the data buffers but hold the pointers to the data directly.
  /// It is advisable not to disturb the buffer contents until the respective
  /// callback is invoked.
/// \param callback The callback function to be invoked on request completion.
/// \param options The options for inference request.
/// \param inputs The vector of InferInput describing the model inputs.
/// \param outputs The vector of InferRequestedOutput describing how the
/// output must be returned. If not provided then all the outputs in the
/// model config will be returned as default settings.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \param request_compression_algorithm Optional HTTP compression algorithm
/// to use for the request body on client side. Currently supports DEFLATE,
/// GZIP and NONE. By default, no compression is used.
/// \param response_compression_algorithm Optional HTTP compression algorithm
/// to request for the response body. Note that the response may not be
/// compressed if the server does not support the specified algorithm.
/// Currently supports DEFLATE, GZIP and NONE. By default, no compression
/// is used.
/// \return Error object indicating success
/// or failure of the request.
Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs =
std::vector<const InferRequestedOutput*>(),
const Headers& headers = Headers(),
const Parameters& query_params = Parameters(),
const CompressionType request_compression_algorithm =
CompressionType::NONE,
const CompressionType response_compression_algorithm =
CompressionType::NONE);
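  /// An illustrative callback sketch (assumes 'options' and 'input' were set
  /// up as in the synchronous example above; the callback owns and must
  /// delete the InferResult):
  /// \code
  ///   client->AsyncInfer(
  ///       [](InferResult* result) {
  ///         if (result->RequestStatus().IsOk()) {
  ///           // ... read outputs from 'result' ...
  ///         }
  ///         delete result;
  ///       },
  ///       options, {input});
  /// \endcode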
/// Run multiple synchronous inferences on server.
/// \param results Returns the results of the inferences.
  /// \param options The options for each inference request. Alternatively, one
  /// set of options may be provided and it will be used for all inference
  /// requests.
/// \param inputs The vector of InferInput objects describing the model inputs
/// for each inference request.
/// \param outputs Optional vector of InferRequestedOutput objects describing
/// how the output must be returned. If not provided then all the outputs in
/// the model config will be returned as default settings. And one set of
/// outputs may be provided and it will be used for all inference requests.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \param request_compression_algorithm Optional HTTP compression algorithm
/// to use for the request body on client side. Currently supports DEFLATE,
/// GZIP and NONE. By default, no compression is used.
/// \param response_compression_algorithm Optional HTTP compression algorithm
/// to request for the response body. Note that the response may not be
/// compressed if the server does not support the specified algorithm.
/// Currently supports DEFLATE, GZIP and NONE. By default, no compression
/// is used.
/// \return Error object indicating success or failure of the
/// request.
Error InferMulti(
std::vector<InferResult*>* results,
const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs =
std::vector<std::vector<const InferRequestedOutput*>>(),
const Headers& headers = Headers(),
const Parameters& query_params = Parameters(),
const CompressionType request_compression_algorithm =
CompressionType::NONE,
const CompressionType response_compression_algorithm =
CompressionType::NONE);
/// Run multiple asynchronous inferences on server.
/// Once all the requests are completed, the vector of InferResult pointers
  /// will be passed to the provided 'callback' function. Upon invocation
  /// of the callback function, the ownership of the InferResult objects is
  /// transferred to the function caller. It is then the caller's choice to
  /// either retrieve the results inside the callback function or defer that
  /// to a different thread so that the client is unblocked. In order to
  /// prevent a memory leak, the user must ensure these objects get deleted.
  /// Note: InferInput::AppendRaw() or InferInput::SetSharedMemory() calls do
  /// not copy the data buffers but hold the pointers to the data directly.
  /// It is advisable not to disturb the buffer contents until the respective
  /// callback is invoked.
/// \param callback The callback function to be invoked on the completion of
/// all requests.
  /// \param options The options for each inference request. Alternatively, one
  /// set of options may be provided and it will be used for all inference
  /// requests.
/// \param inputs The vector of InferInput objects describing the model inputs
/// for each inference request.
/// \param outputs Optional vector of InferRequestedOutput objects describing
/// how the output must be returned. If not provided then all the outputs in
/// the model config will be returned as default settings. And one set of
/// outputs may be provided and it will be used for all inference requests.
/// \param headers Optional map specifying additional HTTP headers to include
/// in request.
/// \param query_params Optional map specifying parameters that must be
/// included with URL query.
/// \param request_compression_algorithm Optional HTTP compression algorithm
/// to use for the request body on client side. Currently supports DEFLATE,
/// GZIP and NONE. By default, no compression is used.
/// \param response_compression_algorithm Optional HTTP compression algorithm
/// to request for the response body. Note that the response may not be
/// compressed if the server does not support the specified algorithm.
/// Currently supports DEFLATE, GZIP and NONE. By default, no compression
/// is used.
/// \return Error object indicating success
/// or failure of the request.
Error AsyncInferMulti(
OnMultiCompleteFn callback, const std::vector<InferOptions>& options,
const std::vector<std::vector<InferInput*>>& inputs,
const std::vector<std::vector<const InferRequestedOutput*>>& outputs =
std::vector<std::vector<const InferRequestedOutput*>>(),
const Headers& headers = Headers(),
const Parameters& query_params = Parameters(),
const CompressionType request_compression_algorithm =
CompressionType::NONE,
const CompressionType response_compression_algorithm =
CompressionType::NONE);
private:
InferenceServerHttpClient(
const std::string& url, bool verbose, const HttpSslOptions& ssl_options);
Error PreRunProcessing(
void* curl, std::string& request_uri, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs,
const Headers& headers, const Parameters& query_params,
const CompressionType request_compression_algorithm,
const CompressionType response_compression_algorithm,
std::shared_ptr<HttpInferRequest>& request);
void AsyncTransfer();
Error Get(
std::string& request_uri, const Headers& headers,
const Parameters& query_params, std::string* response,
long* http_code = nullptr);
Error Post(
std::string& request_uri, const std::string& request,
const Headers& headers, const Parameters& query_params,
std::string* response);
static size_t ResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp);
static size_t InferRequestProvider(
void* contents, size_t size, size_t nmemb, void* userp);
static size_t InferResponseHeaderHandler(
void* contents, size_t size, size_t nmemb, void* userp);
static size_t InferResponseHandler(
void* contents, size_t size, size_t nmemb, void* userp);
// The server url
const std::string url_;
// The options for authorizing and authenticating SSL/TLS connections
HttpSslOptions ssl_options_;
using AsyncReqMap = std::map<uintptr_t, std::shared_ptr<HttpInferRequest>>;
// curl easy handle shared for all synchronous requests
void* easy_handle_;
// curl multi handle for processing asynchronous requests
void* multi_handle_;
// map to record ongoing asynchronous requests with pointer to easy handle
// or tag id as key
AsyncReqMap ongoing_async_requests_;
};
}} // namespace triton::client
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#ifdef TRITON_ENABLE_GPU
#include <cuda_runtime_api.h>
#else
struct cudaIpcMemHandle_t {};
#endif // TRITON_ENABLE_GPU
// Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "json_utils.h"
#include <rapidjson/error/en.h>
namespace triton { namespace client {
Error
ParseJson(rapidjson::Document* document, const std::string& json_str)
{
const unsigned int parseFlags = rapidjson::kParseNanAndInfFlag;
document->Parse<parseFlags>(json_str.c_str(), json_str.size());
if (document->HasParseError()) {
return Error(
"failed to parse JSON at" + std::to_string(document->GetErrorOffset()) +
": " + std::string(GetParseError_En(document->GetParseError())));
}
return Error::Success;
}
}} // namespace triton::client
// Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <rapidjson/document.h>
#include <rapidjson/rapidjson.h>
#include <string>
#include "common.h"
namespace triton { namespace client {
Error ParseJson(rapidjson::Document* document, const std::string& json_str);
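// An illustrative usage sketch (the "name" field is an assumption based on
// the server metadata JSON, not something this helper requires):
//
//   rapidjson::Document doc;
//   Error err = ParseJson(&doc, server_metadata_json);
//   if (err.IsOk() && doc.HasMember("name")) {
//     std::string server_name = doc["name"].GetString();
//   }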
}} // namespace triton::client
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
extern "C++" {
triton::client*;
inference*;
};
local: *;
};
# Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
{
global:
extern "C++" {
triton::client*
};
local: *;
};
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "shm_utils.h"
#include <fcntl.h>
#include <sys/mman.h>
#include <unistd.h>
#include <iostream>
#include <string>
namespace triton { namespace client {
Error
CreateSharedMemoryRegion(std::string shm_key, size_t byte_size, int* shm_fd)
{
// get shared memory region descriptor
*shm_fd = shm_open(shm_key.c_str(), O_RDWR | O_CREAT, S_IRUSR | S_IWUSR);
if (*shm_fd == -1) {
return Error(
"unable to get shared memory descriptor for shared-memory key '" +
shm_key + "'");
}
// extend shared memory object as by default it's initialized with size 0
int res = ftruncate(*shm_fd, byte_size);
if (res == -1) {
return Error(
"unable to initialize shared-memory key '" + shm_key +
"' to requested size: " + std::to_string(byte_size) + " bytes");
}
return Error::Success;
}
Error
MapSharedMemory(int shm_fd, size_t offset, size_t byte_size, void** shm_addr)
{
// map shared memory to process address space
*shm_addr =
mmap(NULL, byte_size, PROT_READ | PROT_WRITE, MAP_SHARED, shm_fd, offset);
if (*shm_addr == MAP_FAILED) {
return Error(
"unable to process address space or shared-memory descriptor: " +
std::to_string(shm_fd));
}
return Error::Success;
}
Error
CloseSharedMemory(int shm_fd)
{
// close shared memory descriptor
if (close(shm_fd) == -1) {
return Error(
"unable to close shared-memory descriptor: " + std::to_string(shm_fd));
}
return Error::Success;
}
Error
UnlinkSharedMemoryRegion(std::string shm_key)
{
int shm_fd = shm_unlink(shm_key.c_str());
if (shm_fd == -1) {
return Error("unable to unlink shared memory for key '" + shm_key + "'");
}
return Error::Success;
}
Error
UnmapSharedMemory(void* shm_addr, size_t byte_size)
{
int tmp_fd = munmap(shm_addr, byte_size);
if (tmp_fd == -1) {
return Error("unable to munmap shared memory region");
}
return Error::Success;
}
}} // namespace triton::client
// Copyright (c) 2020-2021, NVIDIA CORPORATION. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include "common.h"
namespace triton { namespace client {
// Create a shared memory region of the size 'byte_size' and return the unique
// identifier.
// \param shm_key The string identifier of the shared memory region
// \param byte_size The size in bytes of the shared memory region
// \param shm_fd Returns an int descriptor of the created shared memory region
// \return error Returns an error if unable to open shared memory region.
Error CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd);
// Mmap the shared memory region with the given 'offset' and 'byte_size' and
// return the base address of the region.
// \param shm_fd The int descriptor of the created shared memory region
// \param offset The offset of the shared memory block from the start of the
// shared memory region
// \param byte_size The size in bytes of the shared memory region
// \param shm_addr Returns the base address of the shared memory region
// \return error Returns an error if unable to mmap shared memory region.
Error MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr);
// Close the shared memory descriptor.
// \param shm_fd The int descriptor of the created shared memory region
// \return error Returns an error if unable to close shared memory descriptor.
Error CloseSharedMemory(int shm_fd);
// Destroy the shared memory region with the given name.
// \return error Returns an error if unable to unlink shared memory region.
Error UnlinkSharedMemoryRegion(std::string shm_key);
// Munmap the shared memory region from the base address with the given
// byte_size.
// \return error Returns an error if unable to unmap shared memory region.
Error UnmapSharedMemory(void* shm_addr, size_t byte_size);
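// An illustrative end-to-end sketch of the helpers above (the region key
// "/triton_shm_example" and the 64-byte size are placeholders; error
// checking omitted):
//
//   int shm_fd;
//   void* shm_addr;
//   CreateSharedMemoryRegion("/triton_shm_example", 64, &shm_fd);
//   MapSharedMemory(shm_fd, 0 /* offset */, 64, &shm_addr);
//   // ... write tensor bytes into 'shm_addr', register the region with the
//   // server (e.g. via RegisterSystemSharedMemory), run inference ...
//   CloseSharedMemory(shm_fd);
//   UnmapSharedMemory(shm_addr, 64);
//   UnlinkSharedMemoryRegion("/triton_shm_example");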
}} // namespace triton::client
# Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
if(WIN32)
message("perf_analyzer is not currently supported on Windows because "
"is requires functionalities that are UNIX specific.")
else()
add_subdirectory(client_backend)
find_package(Git REQUIRED)
execute_process(COMMAND
"${GIT_EXECUTABLE}" log -n 1 --abbrev-commit --format=format:%h
RESULT_VARIABLE RETURN_CODE
OUTPUT_VARIABLE GIT_SHA)
if(NOT RETURN_CODE EQUAL "0")
set(GIT_SHA "unknown")
endif()
set(
PERF_ANALYZER_SRCS
command_line_parser.cc
perf_analyzer.cc
model_parser.cc
perf_utils.cc
load_manager.cc
data_loader.cc
concurrency_manager.cc
request_rate_manager.cc
load_worker.cc
concurrency_worker.cc
request_rate_worker.cc
custom_load_manager.cc
infer_context.cc
inference_profiler.cc
report_writer.cc
mpi_utils.cc
metrics_manager.cc
infer_data_manager_base.cc
infer_data_manager.cc
infer_data_manager_shm.cc
sequence_manager.cc
profile_data_collector.cc
profile_data_exporter.cc
)
set(
PERF_ANALYZER_HDRS
command_line_parser.h
perf_analyzer.h
model_parser.h
perf_utils.h
load_manager.h
data_loader.h
concurrency_manager.h
request_rate_manager.h
custom_load_manager.h
iworker.h
load_worker.h
request_rate_worker.h
concurrency_worker.h
infer_context.h
inference_profiler.h
report_writer.h
mpi_utils.h
doctest.h
constants.h
metrics.h
metrics_manager.h
infer_data_manager_factory.h
iinfer_data_manager.h
infer_data_manager.h
infer_data_manager_shm.h
infer_data_manager_base.h
infer_data.h
sequence_manager.h
sequence_status.h
ictx_id_tracker.h
concurrency_ctx_id_tracker.h
fifo_ctx_id_tracker.h
rand_ctx_id_tracker.h
request_record.h
profile_data_collector.h
profile_data_exporter.h
)
add_executable(
perf_analyzer
main.cc
${PERF_ANALYZER_SRCS}
${PERF_ANALYZER_HDRS}
$<TARGET_OBJECTS:json-utils-library>
)
target_link_libraries(
perf_analyzer
PRIVATE
client-backend-library
-lb64
${CMAKE_DL_LIBS}
)
target_compile_definitions(
perf_analyzer
PRIVATE
PERF_ANALYZER_VERSION=${PERF_ANALYZER_VERSION}
GIT_SHA=${GIT_SHA}
)
# If gpu is enabled then compile with CUDA dependencies
if(TRITON_ENABLE_GPU)
target_compile_definitions(
perf_analyzer
PUBLIC TRITON_ENABLE_GPU=1
)
target_link_libraries(
perf_analyzer
PRIVATE CUDA::cudart
)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_C_API)
target_compile_definitions(
client-backend-library
PUBLIC TRITON_ENABLE_PERF_ANALYZER_C_API=1
)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TFS)
target_compile_definitions(
client-backend-library
PUBLIC TRITON_ENABLE_PERF_ANALYZER_TFS=1
)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TS)
target_compile_definitions(
client-backend-library
PUBLIC TRITON_ENABLE_PERF_ANALYZER_TS=1
)
endif()
install(
TARGETS perf_analyzer
RUNTIME DESTINATION bin
)
target_compile_definitions(perf_analyzer PUBLIC DOCTEST_CONFIG_DISABLE)
# Creating perf_client link to perf_analyzer binary for backwards compatibility.
install(CODE "execute_process(COMMAND ${CMAKE_COMMAND} -E create_symlink ./perf_analyzer perf_client
WORKING_DIRECTORY ${CMAKE_INSTALL_PREFIX}/bin/)")
install(CODE "message(\"-- Created symlink: perf_client -> ./perf_analyzer\")")
set(PERF_ANALYZER_UNIT_TESTS_SRCS ${PERF_ANALYZER_SRCS})
list(PREPEND PERF_ANALYZER_UNIT_TESTS_SRCS perf_analyzer_unit_tests.cc)
set(PERF_ANALYZER_UNIT_TESTS_HDRS ${PERF_ANALYZER_HDRS})
add_executable(
perf_analyzer_unit_tests
${PERF_ANALYZER_UNIT_TESTS_SRCS}
${PERF_ANALYZER_UNIT_TESTS_HDRS}
mock_inference_profiler.h
mock_model_parser.h
test_utils.h
client_backend/mock_client_backend.h
mock_concurrency_worker.h
mock_data_loader.h
mock_infer_context.h
mock_infer_data_manager.h
mock_request_rate_worker.h
mock_sequence_manager.h
mock_profile_data_collector.h
mock_profile_data_exporter.h
test_dataloader.cc
test_inference_profiler.cc
test_command_line_parser.cc
test_idle_timer.cc
test_load_manager_base.h
test_load_manager.cc
test_model_parser.cc
test_metrics_manager.cc
test_perf_utils.cc
test_report_writer.cc
client_backend/triton/test_triton_client_backend.cc
test_request_rate_manager.cc
test_concurrency_manager.cc
test_custom_load_manager.cc
test_sequence_manager.cc
test_infer_context.cc
test_ctx_id_tracker.cc
test_profile_data_collector.cc
test_profile_data_exporter.cc
$<TARGET_OBJECTS:json-utils-library>
)
# -Wno-write-strings is needed for the unit tests in order to statically create
# input argv cases in the CommandLineParser unit test
#
set_target_properties(perf_analyzer_unit_tests
PROPERTIES COMPILE_FLAGS "-Wno-write-strings")
target_link_libraries(
perf_analyzer_unit_tests
PRIVATE
gmock
client-backend-library
-lb64
)
target_include_directories(
perf_analyzer_unit_tests
PRIVATE
client_backend
)
install(
TARGETS perf_analyzer_unit_tests
RUNTIME DESTINATION bin
)
endif()
<!--
Copyright (c) 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions
are met:
* Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.
* Neither the name of NVIDIA CORPORATION nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->
# Triton Performance Analyzer
Triton Performance Analyzer is a CLI tool that can help you optimize the
inference performance of models running on Triton Inference Server by measuring
changes in performance as you experiment with different optimization strategies.
<br>
# Features
### Inference Load Modes
- [Concurrency Mode](docs/inference_load_modes.md#concurrency-mode) simulates
load by maintaining a specific concurrency of outgoing requests to the
server
- [Request Rate Mode](docs/inference_load_modes.md#request-rate-mode) simulates
load by sending consecutive requests at a specific rate to the server
- [Custom Interval Mode](docs/inference_load_modes.md#custom-interval-mode)
simulates load by sending consecutive requests at specific intervals to the
server
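For example, the load mode is selected on the command line (illustrative
invocations; see the [Perf Analyzer CLI](docs/cli.md) documentation for the
options supported by your version):
```bash
# Concurrency Mode: keep 8 requests outstanding at all times
perf_analyzer -m simple --concurrency-range 8
# Request Rate Mode: send requests at a fixed rate of 100 inferences per second
perf_analyzer -m simple --request-rate-range 100
```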
### Performance Measurement Modes
- [Time Windows Mode](docs/measurements_metrics.md#time-windows) measures model
performance repeatedly over a specific time interval until performance has
stabilized
- [Count Windows Mode](docs/measurements_metrics.md#count-windows) measures
model performance repeatedly over a specific number of requests until
performance has stabilized
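For example (illustrative invocations; check `perf_analyzer --help` for the
exact options available in your version):
```bash
# Time Windows Mode: repeat 5000 ms measurement windows until stable
perf_analyzer -m simple --measurement-mode time_windows --measurement-interval 5000
# Count Windows Mode: measure over windows of at least 50 requests
perf_analyzer -m simple --measurement-mode count_windows --measurement-request-count 50
```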
### Other Features
- [Sequence Models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#stateful-models)
and
[Ensemble Models](https://github.com/triton-inference-server/server/blob/main/docs/user_guide/architecture.md#ensemble-models)
can be profiled in addition to standard/stateless models
- [Input Data](docs/input_data.md) for model inferences can be auto-generated
  or specified by the user, and model output can optionally be verified
- [TensorFlow Serving](docs/benchmarking.md#benchmarking-tensorflow-serving) and
[TorchServe](docs/benchmarking.md#benchmarking-torchserve) can be used as the
inference server in addition to the default Triton server
<br>
# Quick Start
The steps below will guide you on how to start using Perf Analyzer.
### Step 1: Start Triton Container
```bash
export RELEASE=<yy.mm> # e.g. to use the release from the end of February of 2023, do `export RELEASE=23.02`
docker pull nvcr.io/nvidia/tritonserver:${RELEASE}-py3
docker run --gpus all --rm -it --net host nvcr.io/nvidia/tritonserver:${RELEASE}-py3
```
### Step 2: Download `simple` Model
```bash
# inside triton container
git clone --depth 1 https://github.com/triton-inference-server/server
mkdir model_repository ; cp -r server/docs/examples/model_repository/simple model_repository
```
### Step 3: Start Triton Server
```bash
# inside triton container
tritonserver --model-repository $(pwd)/model_repository &> server.log &
# confirm server is ready, look for 'HTTP/1.1 200 OK'
curl -v localhost:8000/v2/health/ready
# detach (CTRL-p CTRL-q)
```
### Step 4: Start Triton SDK Container
```bash
docker pull nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
docker run --gpus all --rm -it --net host nvcr.io/nvidia/tritonserver:${RELEASE}-py3-sdk
```
### Step 5: Run Perf Analyzer
```bash
# inside sdk container
perf_analyzer -m simple
```
See the full [quick start guide](docs/quick_start.md) for additional tips on
how to analyze output.
<br>
# Documentation
- [Installation](docs/install.md)
- [Perf Analyzer CLI](docs/cli.md)
- [Inference Load Modes](docs/inference_load_modes.md)
- [Input Data](docs/input_data.md)
- [Measurements & Metrics](docs/measurements_metrics.md)
- [Benchmarking](docs/benchmarking.md)
<br>
# Contributing
Contributions to Triton Perf Analyzer are more than welcome. To contribute
please review the [contribution
guidelines](https://github.com/triton-inference-server/server/blob/main/CONTRIBUTING.md),
then fork and create a pull request.
<br>
# Reporting problems, asking questions
We appreciate any feedback, questions, or bug reports regarding this
project. When help with code is needed, follow the process outlined in
the Stack Overflow (https://stackoverflow.com/help/mcve)
document. Ensure posted examples are:
- minimal - use as little code as possible that still produces the
  same problem
- complete - provide all parts needed to reproduce the problem. Check
  if you can strip external dependencies and still show the problem. The
  less time we spend on reproducing problems, the more time we have to
  fix them
- verifiable - test the code you're about to provide to make sure it
  reproduces the problem. Remove all other problems that are not
  related to your request/question.
// Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <queue>
#include "ictx_id_tracker.h"
namespace triton { namespace perfanalyzer {
// Base class for CtxIdTrackers that track available IDs via a queue
//
class BaseQueueCtxIdTracker : public ICtxIdTracker {
public:
BaseQueueCtxIdTracker() = default;
void Restore(size_t id) override { free_ctx_ids_.push(id); }
size_t Get() override
{
if (!IsAvailable()) {
throw std::runtime_error("free ctx id list is empty");
}
size_t ctx_id = free_ctx_ids_.front();
free_ctx_ids_.pop();
return ctx_id;
}
bool IsAvailable() override { return free_ctx_ids_.size() > 0; }
protected:
std::queue<size_t> free_ctx_ids_;
// Erase all entries in the tracking queue
//
void Clear()
{
std::queue<size_t> empty;
std::swap(free_ctx_ids_, empty);
}
};
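
// A minimal usage sketch (illustrative only, not part of the library): the
// intended Get()/Restore() cycle for a queue-backed tracker. The tracker is
// assumed to have been populated with free context IDs beforehand.
inline size_t
BorrowAndReturnCtxIdExample(BaseQueueCtxIdTracker& tracker)
{
  if (!tracker.IsAvailable()) {
    throw std::runtime_error("no free context ids to borrow");
  }
  size_t ctx_id = tracker.Get();  // Removes the ID from the free queue
  // ... the ID would be used to select an inference context here ...
  tracker.Restore(ctx_id);  // Hand the ID back once the context is idle
  return ctx_id;
}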
}}  // namespace triton::perfanalyzer
# Copyright 2020-2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
# fixme
add_definitions(-DCURL_STATICLIB)
add_subdirectory(triton)
if(TRITON_ENABLE_PERF_ANALYZER_C_API)
add_subdirectory(triton_c_api)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TFS)
add_subdirectory(tensorflow_serving)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TS)
add_subdirectory(torchserve)
endif()
set(
CLIENT_BACKEND_SRCS
client_backend.cc
)
set(
CLIENT_BACKEND_HDRS
client_backend.h
)
if(TRITON_ENABLE_PERF_ANALYZER_C_API)
set(CAPI_LIBRARY $<TARGET_OBJECTS:triton-c-api-backend-library>)
set(CAPI_TARGET_LINK_LIBRARY PUBLIC $<TARGET_PROPERTY:triton-c-api-backend-library,LINK_LIBRARIES>)
set(CAPI_TARGET_INCLUDE_DIRECTORY PRIVATE $<TARGET_PROPERTY:triton-c-api-backend-library,INCLUDE_DIRECTORIES>)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TFS)
set(TFS_LIBRARY $<TARGET_OBJECTS:tfs-client-backend-library>)
set(TFS_TARGET_LINK_LIBRARY PUBLIC $<TARGET_PROPERTY:tfs-client-backend-library,LINK_LIBRARIES>)
set(TFS_TARGET_INCLUDE_DIRECTORY PRIVATE $<TARGET_PROPERTY:tfs-client-backend-library,INCLUDE_DIRECTORIES>)
endif()
if(TRITON_ENABLE_PERF_ANALYZER_TS)
set(TS_LIBRARY $<TARGET_OBJECTS:ts-client-backend-library>)
set(TS_TARGET_LINK_LIBRARY PUBLIC $<TARGET_PROPERTY:ts-client-backend-library,LINK_LIBRARIES>)
set(TS_TARGET_INCLUDE_DIRECTORY PRIVATE $<TARGET_PROPERTY:ts-client-backend-library,INCLUDE_DIRECTORIES>)
endif()
add_library(
client-backend-library
${CLIENT_BACKEND_SRCS}
${CLIENT_BACKEND_HDRS}
$<TARGET_OBJECTS:triton-client-backend-library>
$<TARGET_OBJECTS:shm-utils-library>
${CAPI_LIBRARY}
${TFS_LIBRARY}
${TS_LIBRARY}
)
target_link_libraries(
client-backend-library
PUBLIC triton-common-json # from repo-common
PUBLIC $<TARGET_PROPERTY:triton-client-backend-library,LINK_LIBRARIES>
${CAPI_TARGET_LINK_LIBRARY}
${TFS_TARGET_LINK_LIBRARY}
${TS_TARGET_LINK_LIBRARY}
)
target_include_directories(
client-backend-library
PRIVATE $<TARGET_PROPERTY:triton-client-backend-library,INCLUDE_DIRECTORIES>
${CAPI_TARGET_INCLUDE_DIRECTORY}
${TFS_TARGET_INCLUDE_DIRECTORY}
${TS_TARGET_INCLUDE_DIRECTORY}
)
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "client_backend.h"
#include "triton/triton_client_backend.h"
#ifdef TRITON_ENABLE_PERF_ANALYZER_C_API
#include "triton_c_api/triton_c_api_backend.h"
#endif // TRITON_ENABLE_PERF_ANALYZER_C_API
#ifdef TRITON_ENABLE_PERF_ANALYZER_TFS
#include "tensorflow_serving/tfserve_client_backend.h"
#endif // TRITON_ENABLE_PERF_ANALYZER_TFS
#ifdef TRITON_ENABLE_PERF_ANALYZER_TS
#include "torchserve/torchserve_client_backend.h"
#endif // TRITON_ENABLE_PERF_ANALYZER_TS
namespace triton { namespace perfanalyzer { namespace clientbackend {
//================================================
const Error Error::Success("", pa::SUCCESS);
const Error Error::Failure("", pa::GENERIC_ERROR);
Error::Error() : msg_(""), error_(pa::SUCCESS) {}
Error::Error(const std::string& msg, const uint32_t err)
: msg_(msg), error_(err)
{
}
Error::Error(const std::string& msg) : msg_(msg)
{
error_ = pa::GENERIC_ERROR;
}
std::ostream&
operator<<(std::ostream& out, const Error& err)
{
if (!err.msg_.empty()) {
out << err.msg_ << std::endl;
}
return out;
}
//================================================
std::string
BackendKindToString(const BackendKind kind)
{
  switch (kind) {
    case TRITON:
      return std::string("TRITON");
    case TENSORFLOW_SERVING:
      return std::string("TENSORFLOW_SERVING");
    case TORCHSERVE:
      return std::string("TORCHSERVE");
    case TRITON_C_API:
      return std::string("TRITON_C_API");
    default:
      return std::string("UNKNOWN");
  }
}
grpc_compression_algorithm
BackendToGrpcType(const GrpcCompressionAlgorithm compression_algorithm)
{
switch (compression_algorithm) {
case COMPRESS_DEFLATE:
return grpc_compression_algorithm::GRPC_COMPRESS_DEFLATE;
case COMPRESS_GZIP:
return grpc_compression_algorithm::GRPC_COMPRESS_GZIP;
default:
return grpc_compression_algorithm::GRPC_COMPRESS_NONE;
}
}
//================================================
//
// ClientBackendFactory
//
Error
ClientBackendFactory::Create(
const BackendKind kind, const std::string& url, const ProtocolType protocol,
const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const GrpcCompressionAlgorithm compression_algorithm,
std::shared_ptr<Headers> http_headers,
const std::string& triton_server_path,
const std::string& model_repository_path, const bool verbose,
const std::string& metrics_url, const cb::TensorFormat input_tensor_format,
const cb::TensorFormat output_tensor_format,
std::shared_ptr<ClientBackendFactory>* factory)
{
factory->reset(new ClientBackendFactory(
kind, url, protocol, ssl_options, trace_options, compression_algorithm,
http_headers, triton_server_path, model_repository_path, verbose,
metrics_url, input_tensor_format, output_tensor_format));
return Error::Success;
}
Error
ClientBackendFactory::CreateClientBackend(
std::unique_ptr<ClientBackend>* client_backend)
{
RETURN_IF_CB_ERROR(ClientBackend::Create(
kind_, url_, protocol_, ssl_options_, trace_options_,
compression_algorithm_, http_headers_, verbose_, triton_server_path,
model_repository_path_, metrics_url_, input_tensor_format_,
output_tensor_format_, client_backend));
return Error::Success;
}
const BackendKind&
ClientBackendFactory::Kind()
{
return kind_;
}
//
// ClientBackend
//
Error
ClientBackend::Create(
const BackendKind kind, const std::string& url, const ProtocolType protocol,
const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const GrpcCompressionAlgorithm compression_algorithm,
std::shared_ptr<Headers> http_headers, const bool verbose,
const std::string& triton_server_path,
const std::string& model_repository_path, const std::string& metrics_url,
const TensorFormat input_tensor_format,
const TensorFormat output_tensor_format,
std::unique_ptr<ClientBackend>* client_backend)
{
std::unique_ptr<ClientBackend> local_backend;
if (kind == TRITON) {
RETURN_IF_CB_ERROR(tritonremote::TritonClientBackend::Create(
url, protocol, ssl_options, trace_options,
BackendToGrpcType(compression_algorithm), http_headers, verbose,
metrics_url, input_tensor_format, output_tensor_format,
&local_backend));
}
#ifdef TRITON_ENABLE_PERF_ANALYZER_TFS
else if (kind == TENSORFLOW_SERVING) {
RETURN_IF_CB_ERROR(tfserving::TFServeClientBackend::Create(
url, protocol, BackendToGrpcType(compression_algorithm), http_headers,
verbose, &local_backend));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_TFS
#ifdef TRITON_ENABLE_PERF_ANALYZER_TS
else if (kind == TORCHSERVE) {
RETURN_IF_CB_ERROR(torchserve::TorchServeClientBackend::Create(
url, protocol, http_headers, verbose, &local_backend));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_TS
#ifdef TRITON_ENABLE_PERF_ANALYZER_C_API
else if (kind == TRITON_C_API) {
RETURN_IF_CB_ERROR(tritoncapi::TritonCApiClientBackend::Create(
triton_server_path, model_repository_path, verbose, &local_backend));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_C_API
else {
return Error("unsupported client backend requested", pa::GENERIC_ERROR);
}
*client_backend = std::move(local_backend);
return Error::Success;
}
Error
ClientBackend::ServerExtensions(std::set<std::string>* server_extensions)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support ServerExtensions API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support ModelMetadata API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::ModelConfig(
rapidjson::Document* model_config, const std::string& model_name,
const std::string& model_version)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support ModelConfig API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support Infer API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support AsyncInfer API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::StartStream(OnCompleteFn callback, bool enable_stats)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support StartStream API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support AsyncStreamInfer API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::ClientInferStat(InferStat* infer_stat)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support ClientInferStat API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::ModelInferenceStatistics(
std::map<ModelIdentifier, ModelStatistics>* model_stats,
const std::string& model_name, const std::string& model_version)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support ModelInferenceStatistics API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::Metrics(triton::perfanalyzer::Metrics& metrics)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support Metrics API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::UnregisterAllSharedMemory()
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support UnregisterAllSharedMemory API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support RegisterSystemSharedMemory API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& handle,
const size_t byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support RegisterCudaSharedMemory API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::RegisterCudaMemory(
const std::string& name, void* handle, const size_t byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support RegisterCudaMemory API",
pa::GENERIC_ERROR);
}
Error
ClientBackend::RegisterSystemMemory(
const std::string& name, void* memory_ptr, const size_t byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support RegisterCudaMemory API",
pa::GENERIC_ERROR);
}
//
// Shared Memory Utilities
//
Error
ClientBackend::CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support CreateSharedMemoryRegion()",
pa::GENERIC_ERROR);
}
Error
ClientBackend::MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support MapSharedMemory()",
pa::GENERIC_ERROR);
}
Error
ClientBackend::CloseSharedMemory(int shm_fd)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support CloseSharedMemory()",
pa::GENERIC_ERROR);
}
Error
ClientBackend::UnlinkSharedMemoryRegion(std::string shm_key)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support UnlinkSharedMemoryRegion()",
pa::GENERIC_ERROR);
}
Error
ClientBackend::UnmapSharedMemory(void* shm_addr, size_t byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support UnmapSharedMemory()",
pa::GENERIC_ERROR);
}
ClientBackend::ClientBackend(const BackendKind kind) : kind_(kind) {}
//
// InferInput
//
Error
InferInput::Create(
InferInput** infer_input, const BackendKind kind, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype)
{
if (kind == TRITON) {
RETURN_IF_CB_ERROR(tritonremote::TritonInferInput::Create(
infer_input, name, dims, datatype));
}
#ifdef TRITON_ENABLE_PERF_ANALYZER_TFS
else if (kind == TENSORFLOW_SERVING) {
RETURN_IF_CB_ERROR(tfserving::TFServeInferInput::Create(
infer_input, name, dims, datatype));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_TFS
#ifdef TRITON_ENABLE_PERF_ANALYZER_TS
else if (kind == TORCHSERVE) {
RETURN_IF_CB_ERROR(torchserve::TorchServeInferInput::Create(
infer_input, name, dims, datatype));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_TS
#ifdef TRITON_ENABLE_PERF_ANALYZER_C_API
else if (kind == TRITON_C_API) {
RETURN_IF_CB_ERROR(tritoncapi::TritonCApiInferInput::Create(
infer_input, name, dims, datatype));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_C_API
else {
return Error(
"unsupported client backend provided to create InferInput object",
pa::GENERIC_ERROR);
}
return Error::Success;
}
Error
InferInput::SetShape(const std::vector<int64_t>& shape)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support SetShape() for InferInput",
pa::GENERIC_ERROR);
}
Error
InferInput::Reset()
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support Reset() for InferInput",
pa::GENERIC_ERROR);
}
Error
InferInput::AppendRaw(const uint8_t* input, size_t input_byte_size)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support AppendRaw() for InferInput",
pa::GENERIC_ERROR);
}
Error
InferInput::SetSharedMemory(
const std::string& name, size_t byte_size, size_t offset)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support SetSharedMemory() for InferInput",
pa::GENERIC_ERROR);
}
InferInput::InferInput(
const BackendKind kind, const std::string& name,
const std::string& datatype)
: kind_(kind), name_(name), datatype_(datatype)
{
}
//
// InferRequestedOutput
//
Error
InferRequestedOutput::Create(
InferRequestedOutput** infer_output, const BackendKind kind,
const std::string& name, const size_t class_count)
{
if (kind == TRITON) {
RETURN_IF_CB_ERROR(tritonremote::TritonInferRequestedOutput::Create(
infer_output, name, class_count));
}
#ifdef TRITON_ENABLE_PERF_ANALYZER_TFS
else if (kind == TENSORFLOW_SERVING) {
RETURN_IF_CB_ERROR(
tfserving::TFServeInferRequestedOutput::Create(infer_output, name));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_TFS
#ifdef TRITON_ENABLE_PERF_ANALYZER_C_API
else if (kind == TRITON_C_API) {
RETURN_IF_CB_ERROR(tritoncapi::TritonCApiInferRequestedOutput::Create(
infer_output, name, class_count));
}
#endif // TRITON_ENABLE_PERF_ANALYZER_C_API
else {
return Error(
"unsupported client backend provided to create InferRequestedOutput "
"object",
pa::GENERIC_ERROR);
}
return Error::Success;
}
Error
InferRequestedOutput::SetSharedMemory(
const std::string& region_name, size_t byte_size, size_t offset)
{
return Error(
"client backend of kind " + BackendKindToString(kind_) +
" does not support SetSharedMemory() for InferRequestedOutput",
pa::GENERIC_ERROR);
}
InferRequestedOutput::InferRequestedOutput(
const BackendKind kind, const std::string& name)
: kind_(kind), name_(name)
{
}
}}} // namespace triton::perfanalyzer::clientbackend
// Copyright 2020-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <rapidjson/document.h>
#include <rapidjson/rapidjson.h>
#include <functional>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "../constants.h"
#include "../metrics.h"
#include "../perf_analyzer_exception.h"
#include "ipc.h"
namespace pa = triton::perfanalyzer;
namespace triton { namespace perfanalyzer { namespace clientbackend {
#define RETURN_IF_CB_ERROR(S) \
do { \
const triton::perfanalyzer::clientbackend::Error& status__ = (S); \
if (!status__.IsOk()) { \
return status__; \
} \
} while (false)
#define RETURN_IF_ERROR(S) \
do { \
triton::perfanalyzer::clientbackend::Error status__ = (S); \
if (!status__.IsOk()) { \
return status__; \
} \
} while (false)
#define FAIL_IF_ERR(X, MSG)                                          \
  do {                                                               \
    triton::perfanalyzer::clientbackend::Error err = (X);            \
    if (!err.IsOk()) {                                               \
      std::cerr << "error: " << (MSG) << ": " << err << std::endl;   \
      exit(err.Err());                                               \
    }                                                                 \
  } while (false)
#define THROW_IF_ERROR(S, MSG) \
do { \
triton::perfanalyzer::clientbackend::Error status__ = (S); \
if (!status__.IsOk()) { \
std::cerr << "error: " << (MSG) << ": " << status__ << std::endl; \
throw PerfAnalyzerException(GENERIC_ERROR); \
} \
} while (false)
//==============================================================================
/// Error status reported by backends
///
class Error {
public:
/// Create an error
explicit Error();
/// Create an error with the specified message and error code.
/// \param msg The message for the error
/// \param err The error code for the error
explicit Error(const std::string& msg, const uint32_t err);
/// Create an error with the specified message.
/// \param msg The message for the error
explicit Error(const std::string& msg);
/// Accessor for the message of this error.
/// \return The message for the error. Empty if no error.
const std::string& Message() const { return msg_; }
/// Accessor for the error code.
/// \return The error code for the error. 0 if no error.
  uint32_t Err() const { return error_; }
/// Does this error indicate OK status?
/// \return True if this error indicates "ok"/"success", false if
/// error indicates a failure.
bool IsOk() const { return error_ == 0; }
/// Convenience "success" value. Can be used as Error::Success to
/// indicate no error.
static const Error Success;
/// Convenience "failure" value. Can be used as Error::Failure to
/// indicate a generic error.
static const Error Failure;
private:
friend std::ostream& operator<<(std::ostream&, const Error&);
std::string msg_{""};
uint32_t error_{pa::SUCCESS};
};
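
// A minimal sketch (illustrative only) of how Error composes with the
// RETURN_IF_CB_ERROR macro defined above: the first failing status is
// propagated to the caller unchanged. ValidateExampleModelName and
// CheckExampleRequest are hypothetical helpers, not part of the client
// backend API.
inline Error
ValidateExampleModelName(const std::string& model_name)
{
  if (model_name.empty()) {
    return Error("model name must not be empty", pa::GENERIC_ERROR);
  }
  return Error::Success;
}

inline Error
CheckExampleRequest(const std::string& model_name)
{
  RETURN_IF_CB_ERROR(ValidateExampleModelName(model_name));
  // ... further request construction would go here ...
  return Error::Success;
}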
//===================================================================================
class ClientBackend;
class InferInput;
class InferRequestedOutput;
class InferResult;
enum BackendKind {
TRITON = 0,
TENSORFLOW_SERVING = 1,
TORCHSERVE = 2,
TRITON_C_API = 3
};
enum ProtocolType { HTTP = 0, GRPC = 1, UNKNOWN = 2 };
enum GrpcCompressionAlgorithm {
COMPRESS_NONE = 0,
COMPRESS_DEFLATE = 1,
COMPRESS_GZIP = 2
};
enum class TensorFormat { BINARY, JSON, UNKNOWN };
typedef std::map<std::string, std::string> Headers;
using OnCompleteFn = std::function<void(InferResult*)>;
using ModelIdentifier = std::pair<std::string, std::string>;
struct InferStat {
/// Total number of requests completed.
size_t completed_request_count;
/// Time from the request start until the response is completely
/// received.
uint64_t cumulative_total_request_time_ns;
/// Time from the request start until the last byte is sent.
uint64_t cumulative_send_time_ns;
/// Time from receiving first byte of the response until the
/// response is completely received.
uint64_t cumulative_receive_time_ns;
/// Create a new InferStat object with zero-ed statistics.
InferStat()
: completed_request_count(0), cumulative_total_request_time_ns(0),
cumulative_send_time_ns(0), cumulative_receive_time_ns(0)
{
}
};
// Per model statistics
struct ModelStatistics {
uint64_t success_count_;
uint64_t inference_count_;
uint64_t execution_count_;
uint64_t queue_count_;
uint64_t compute_input_count_;
uint64_t compute_infer_count_;
uint64_t compute_output_count_;
uint64_t cache_hit_count_;
uint64_t cache_miss_count_;
uint64_t cumm_time_ns_;
uint64_t queue_time_ns_;
uint64_t compute_input_time_ns_;
uint64_t compute_infer_time_ns_;
uint64_t compute_output_time_ns_;
uint64_t cache_hit_time_ns_;
uint64_t cache_miss_time_ns_;
};
//==============================================================================
/// Structure to hold options for Inference Request.
///
struct InferOptions {
explicit InferOptions(const std::string& model_name)
: model_name_(model_name), model_version_(""), request_id_(""),
sequence_id_(0), sequence_id_str_(""), sequence_start_(false),
sequence_end_(false), triton_enable_empty_final_response_(true)
{
}
/// The name of the model to run inference.
std::string model_name_;
/// The version of the model.
std::string model_version_;
/// The model signature name for TF models.
std::string model_signature_name_;
/// An identifier for the request.
std::string request_id_;
/// The unique identifier for the sequence being represented by the
/// object. Default value is 0 which means that the request does not
/// belong to a sequence. If this value is set, then sequence_id_str_
/// MUST be set to "".
uint64_t sequence_id_;
/// The unique identifier for the sequence being represented by the
/// object. Default value is "" which means that the request does not
/// belong to a sequence. If this value is set, then sequence_id_ MUST
/// be set to 0.
std::string sequence_id_str_;
/// Indicates whether the request being added marks the start of the
/// sequence. Default value is False. This argument is ignored if
/// 'sequence_id' is 0.
bool sequence_start_;
/// Indicates whether the request being added marks the end of the
/// sequence. Default value is False. This argument is ignored if
/// 'sequence_id' is 0.
bool sequence_end_;
/// Whether to tell Triton to enable an empty final response.
bool triton_enable_empty_final_response_;
};
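
// A minimal sketch (illustrative only): filling in InferOptions for the first
// request of a sequence. The model name "simple" and sequence ID 42 are
// arbitrary example values, not defaults of the library.
inline InferOptions
MakeExampleSequenceStartOptions()
{
  InferOptions options("simple");
  options.sequence_id_ = 42;       // Non-zero: the request belongs to a sequence
  options.sequence_start_ = true;  // Marks the first request of the sequence
  options.sequence_end_ = false;   // The last request of the sequence would set this
  return options;
}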
struct SslOptionsBase {
bool ssl_grpc_use_ssl = false;
std::string ssl_grpc_root_certifications_file = "";
std::string ssl_grpc_private_key_file = "";
std::string ssl_grpc_certificate_chain_file = "";
long ssl_https_verify_peer = 1L;
long ssl_https_verify_host = 2L;
std::string ssl_https_ca_certificates_file = "";
std::string ssl_https_client_certificate_file = "";
std::string ssl_https_client_certificate_type = "";
std::string ssl_https_private_key_file = "";
std::string ssl_https_private_key_type = "";
};
//
// The object factory to create client backends to communicate with the
// inference service
//
class ClientBackendFactory {
public:
/// Create a factory that can be used to construct Client Backends.
/// \param kind The kind of client backend to create.
/// \param url The inference server url and port.
/// \param protocol The protocol type used.
/// \param ssl_options The SSL options used with client backend.
/// \param compression_algorithm The compression algorithm to be used
/// on the grpc requests.
/// \param http_headers Map of HTTP headers. The map key/value
/// indicates the header name/value. The headers will be included
/// with all the requests made to server using this client.
  /// \param triton_server_path Only for the C API backend. Path to the
  /// top-level Triton directory (typically /opt/tritonserver). Must contain
  /// libtritonserver.so.
/// \param model_repository_path Only for C api backend. Path to model
/// repository which contains the desired model.
/// \param verbose Enables the verbose mode.
/// \param metrics_url The inference server metrics url and port.
/// \param input_tensor_format The Triton inference request input tensor
/// format.
/// \param output_tensor_format The Triton inference response output tensor
/// format.
/// \param factory Returns a new ClientBackend object.
/// \return Error object indicating success or failure.
static Error Create(
const BackendKind kind, const std::string& url,
const ProtocolType protocol, const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const GrpcCompressionAlgorithm compression_algorithm,
std::shared_ptr<Headers> http_headers,
const std::string& triton_server_path,
const std::string& model_repository_path, const bool verbose,
const std::string& metrics_url, const TensorFormat input_tensor_format,
const TensorFormat output_tensor_format,
std::shared_ptr<ClientBackendFactory>* factory);
const BackendKind& Kind();
/// Create a ClientBackend.
/// \param backend Returns a new Client backend object.
virtual Error CreateClientBackend(std::unique_ptr<ClientBackend>* backend);
private:
ClientBackendFactory(
const BackendKind kind, const std::string& url,
const ProtocolType protocol, const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const GrpcCompressionAlgorithm compression_algorithm,
const std::shared_ptr<Headers> http_headers,
const std::string& triton_server_path,
const std::string& model_repository_path, const bool verbose,
const std::string& metrics_url, const TensorFormat input_tensor_format,
const TensorFormat output_tensor_format)
: kind_(kind), url_(url), protocol_(protocol), ssl_options_(ssl_options),
trace_options_(trace_options),
compression_algorithm_(compression_algorithm),
http_headers_(http_headers), triton_server_path(triton_server_path),
model_repository_path_(model_repository_path), verbose_(verbose),
metrics_url_(metrics_url), input_tensor_format_(input_tensor_format),
output_tensor_format_(output_tensor_format)
{
}
const BackendKind kind_;
const std::string url_;
const ProtocolType protocol_;
const SslOptionsBase& ssl_options_;
const std::map<std::string, std::vector<std::string>> trace_options_;
const GrpcCompressionAlgorithm compression_algorithm_;
std::shared_ptr<Headers> http_headers_;
std::string triton_server_path;
std::string model_repository_path_;
const bool verbose_;
const std::string metrics_url_{""};
const TensorFormat input_tensor_format_{TensorFormat::UNKNOWN};
const TensorFormat output_tensor_format_{TensorFormat::UNKNOWN};
#ifndef DOCTEST_CONFIG_DISABLE
protected:
ClientBackendFactory()
: kind_(BackendKind()), url_(""), protocol_(ProtocolType()),
ssl_options_(SslOptionsBase()),
trace_options_(std::map<std::string, std::vector<std::string>>()),
compression_algorithm_(GrpcCompressionAlgorithm()), verbose_(false)
{
}
#endif
};
//
// Interface for interacting with an inference service
//
class ClientBackend {
public:
static Error Create(
const BackendKind kind, const std::string& url,
const ProtocolType protocol, const SslOptionsBase& ssl_options,
const std::map<std::string, std::vector<std::string>> trace_options,
const GrpcCompressionAlgorithm compression_algorithm,
std::shared_ptr<Headers> http_headers, const bool verbose,
const std::string& library_directory, const std::string& model_repository,
const std::string& metrics_url, const TensorFormat input_tensor_format,
const TensorFormat output_tensor_format,
std::unique_ptr<ClientBackend>* client_backend);
/// Destructor for the client backend object
virtual ~ClientBackend() = default;
/// Get the backend kind
BackendKind Kind() const { return kind_; }
/// Get the server metadata from the server
virtual Error ServerExtensions(std::set<std::string>* server_extensions);
/// Get the model metadata from the server for specified name and
/// version as rapidjson DOM object.
virtual Error ModelMetadata(
rapidjson::Document* model_metadata, const std::string& model_name,
const std::string& model_version);
/// Get the model config from the server for specified name and
/// version as rapidjson DOM object.
virtual Error ModelConfig(
rapidjson::Document* model_config, const std::string& model_name,
const std::string& model_version);
/// Issues a synchronous inference request to the server.
virtual Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
/// Issues an asynchronous inference request to the server.
virtual Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
  /// Establishes a stream to the server.
virtual Error StartStream(OnCompleteFn callback, bool enable_stats);
/// Issues an asynchronous inference request to the underlying stream.
virtual Error AsyncStreamInfer(
const InferOptions& options, const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs);
/// Gets the client side inference statistics from the client library.
virtual Error ClientInferStat(InferStat* infer_stat);
/// Gets the server-side model inference statistics from the server.
virtual Error ModelInferenceStatistics(
std::map<ModelIdentifier, ModelStatistics>* model_stats,
const std::string& model_name = "",
const std::string& model_version = "");
/// Gets the server-side metrics from the server.
/// \param metrics Output metrics object.
/// \return Error object indicating success or failure.
virtual Error Metrics(Metrics& metrics);
/// Unregisters all the shared memory from the server
virtual Error UnregisterAllSharedMemory();
  /// Registers a system shared memory region with the server.
virtual Error RegisterSystemSharedMemory(
const std::string& name, const std::string& key, const size_t byte_size);
/// Registers cuda shared memory to the server.
virtual Error RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& handle,
const size_t byte_size);
/// Registers cuda memory to the server.
virtual Error RegisterCudaMemory(
const std::string& name, void* handle, const size_t byte_size);
/// Registers a system memory location on the server.
virtual Error RegisterSystemMemory(
const std::string& name, void* memory_ptr, const size_t byte_size);
//
// Shared Memory Utilities
//
  // FIXME: These should probably move to a common area with shm_utils not
  // tied specifically to inferenceserver.
  //
  // Create a shared memory region of the size 'byte_size' and return the
  // unique identifier.
virtual Error CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd);
// Mmap the shared memory region with the given 'offset' and 'byte_size' and
// return the base address of the region.
// \param shm_fd The int descriptor of the created shared memory region
// \param offset The offset of the shared memory block from the start of the
// shared memory region
// \param byte_size The size in bytes of the shared memory region
// \param shm_addr Returns the base address of the shared memory region
// \return error Returns an error if unable to mmap shared memory region.
virtual Error MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr);
// Close the shared memory descriptor.
// \param shm_fd The int descriptor of the created shared memory region
// \return error Returns an error if unable to close shared memory descriptor.
virtual Error CloseSharedMemory(int shm_fd);
// Destroy the shared memory region with the given name.
// \return error Returns an error if unable to unlink shared memory region.
virtual Error UnlinkSharedMemoryRegion(std::string shm_key);
// Munmap the shared memory region from the base address with the given
// byte_size.
// \return error Returns an error if unable to unmap shared memory region.
virtual Error UnmapSharedMemory(void* shm_addr, size_t byte_size);
protected:
/// Constructor for client backend
ClientBackend(const BackendKind kind);
// The kind of the backend.
const BackendKind kind_{TRITON};
#ifndef DOCTEST_CONFIG_DISABLE
public:
ClientBackend() = default;
#endif
};
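
// A minimal sketch (illustrative only) of wiring a ClientBackendFactory to a
// ClientBackend for the plain Triton HTTP case. The server URL, the metrics
// URL, and the empty trace options / headers are example assumptions; real
// callers pass the values parsed from the perf_analyzer command line.
inline Error
CreateExampleTritonBackend(std::unique_ptr<ClientBackend>* backend)
{
  SslOptionsBase ssl_options;
  std::map<std::string, std::vector<std::string>> trace_options;
  std::shared_ptr<Headers> http_headers = std::make_shared<Headers>();
  std::shared_ptr<ClientBackendFactory> factory;
  RETURN_IF_CB_ERROR(ClientBackendFactory::Create(
      TRITON, "localhost:8000", HTTP, ssl_options, trace_options,
      COMPRESS_NONE, http_headers, "" /* triton_server_path */,
      "" /* model_repository_path */, false /* verbose */,
      "localhost:8002/metrics", TensorFormat::BINARY, TensorFormat::BINARY,
      &factory));
  return factory->CreateClientBackend(backend);
}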
//
// Interface for preparing the inputs for inference to the backend
//
class InferInput {
public:
  /// Create an InferInput instance that describes a model input.
  /// \param infer_input Returns a new InferInput object.
  /// \param kind The kind of the associated client backend.
  /// \param name The name of the input whose data will be described by this
  /// object.
/// \param dims The shape of the input.
/// \param datatype The datatype of the input.
/// \return Error object indicating success or failure.
static Error Create(
InferInput** infer_input, const BackendKind kind, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype);
virtual ~InferInput() = default;
/// Gets name of the associated input tensor.
/// \return The name of the tensor.
const std::string& Name() const { return name_; }
/// Gets datatype of the associated input tensor.
/// \return The datatype of the tensor.
const std::string& Datatype() const { return datatype_; }
/// Gets the shape of the input tensor.
/// \return The shape of the tensor.
virtual const std::vector<int64_t>& Shape() const = 0;
/// Set the shape of input associated with this object.
/// \param dims the vector of dims representing the new shape
/// of input.
/// \return Error object indicating success or failure of the
/// request.
virtual Error SetShape(const std::vector<int64_t>& dims);
/// Prepare this input to receive new tensor values. Forget any
/// existing values that were set by previous calls to SetSharedMemory()
/// or AppendRaw().
/// \return Error object indicating success or failure.
virtual Error Reset();
/// Append tensor values for this input from a byte array.
/// \param input The pointer to the array holding the tensor value.
/// \param input_byte_size The size of the array in bytes.
/// \return Error object indicating success or failure.
virtual Error AppendRaw(const uint8_t* input, size_t input_byte_size);
/// Set tensor values for this input by reference into a shared memory
/// region.
  /// \param name The user-given name for the registered shared memory region
  /// where the tensor values for this input are stored.
  /// \param byte_size The size, in bytes, of the input tensor data. Must
  /// match the size expected for the input shape.
  /// \param offset The offset into the shared memory region up to the start
  /// of the input tensor values. The default value is 0.
/// \return Error object indicating success or failure
virtual Error SetSharedMemory(
const std::string& name, size_t byte_size, size_t offset = 0);
protected:
  InferInput(
      const BackendKind kind, const std::string& name,
      const std::string& datatype);
const BackendKind kind_;
const std::string name_;
const std::string datatype_;
};
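
// A minimal sketch (illustrative only) of preparing a single-element INT32
// input for the Triton backend. The tensor name "INPUT0" and the value are
// example assumptions.
inline Error
PrepareExampleInput(InferInput** input)
{
  RETURN_IF_CB_ERROR(
      InferInput::Create(input, TRITON, "INPUT0", {1}, "INT32"));
  const int32_t value = 7;
  return (*input)->AppendRaw(
      reinterpret_cast<const uint8_t*>(&value), sizeof(value));
}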
//
// Interface for specifying the outputs requested from the backend
//
class InferRequestedOutput {
public:
virtual ~InferRequestedOutput() = default;
  /// Create an InferRequestedOutput instance that describes a model output
  /// being requested.
  /// \param infer_output Returns a new InferRequestedOutput object.
/// \param kind The kind of the associated client backend.
/// \param name The name of output being requested.
/// \param class_count The number of classifications to be requested. The
/// default value is 0 which means the classification results are not
/// requested.
/// \return Error object indicating success or failure.
static Error Create(
InferRequestedOutput** infer_output, const BackendKind kind,
const std::string& name, const size_t class_count = 0);
/// Gets name of the associated output tensor.
/// \return The name of the tensor.
const std::string& Name() const { return name_; }
/// Set the output tensor data to be written to specified shared
/// memory region.
/// \param region_name The name of the shared memory region.
/// \param byte_size The size of data in bytes.
/// \param offset The offset in shared memory region. Default value is 0.
/// \return Error object indicating success or failure of the
/// request.
virtual Error SetSharedMemory(
const std::string& region_name, const size_t byte_size,
const size_t offset = 0);
protected:
InferRequestedOutput(const BackendKind kind, const std::string& name);
const BackendKind kind_;
const std::string name_;
};
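
// A minimal sketch (illustrative only): requesting the output tensor "OUTPUT0"
// and directing it into a shared memory region. Both the tensor name and the
// region name "output_data" are example assumptions; the region would have to
// be registered with the backend beforehand.
inline Error
PrepareExampleOutput(InferRequestedOutput** output)
{
  RETURN_IF_CB_ERROR(InferRequestedOutput::Create(output, TRITON, "OUTPUT0"));
  return (*output)->SetSharedMemory("output_data", sizeof(int32_t));
}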
//
// Interface for accessing the processed results.
//
class InferResult {
public:
virtual ~InferResult() = default;
/// Get the id of the request which generated this response.
/// \param id Returns the request id that generated the result.
/// \return Error object indicating success or failure.
virtual Error Id(std::string* id) const = 0;
/// Returns the status of the request.
/// \return Error object indicating the success or failure of the
/// request.
virtual Error RequestStatus() const = 0;
/// Returns the raw data of the output.
/// \return Error object indicating the success or failure of the
/// request.
virtual Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const = 0;
/// Get final response bool for this response.
/// \return Error object indicating the success or failure.
virtual Error IsFinalResponse(bool* is_final_response) const
{
return Error("InferResult::IsFinalResponse() not implemented");
};
/// Get null response bool for this response.
/// \return Error object indicating the success or failure.
virtual Error IsNullResponse(bool* is_null_response) const
{
return Error("InferResult::IsNullResponse() not implemented");
};
};
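
// A minimal sketch (illustrative only) of consuming a completed InferResult:
// check the request status, then read a single INT32 value out of the raw
// output buffer. The output name "OUTPUT0" is an example assumption.
inline Error
ReadExampleOutput(const InferResult& result, int32_t* value)
{
  RETURN_IF_CB_ERROR(result.RequestStatus());
  const uint8_t* buf = nullptr;
  size_t byte_size = 0;
  RETURN_IF_CB_ERROR(result.RawData("OUTPUT0", &buf, &byte_size));
  if (byte_size != sizeof(int32_t)) {
    return Error("unexpected output byte size", pa::GENERIC_ERROR);
  }
  *value = *reinterpret_cast<const int32_t*>(buf);
  return Error::Success;
}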
}}} // namespace triton::perfanalyzer::clientbackend
namespace cb = triton::perfanalyzer::clientbackend;
// Copyright 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in the
// documentation and/or other materials provided with the distribution.
// * Neither the name of NVIDIA CORPORATION nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
// EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
// IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
// PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
// EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
// PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
// PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
// OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#pragma once
#include <atomic>
#include <chrono>
#include <mutex>
#include <thread>
#include "../doctest.h"
#include "client_backend.h"
#include "gmock/gmock.h"
namespace triton { namespace perfanalyzer { namespace clientbackend {
// Holds information (either the raw data or a shared memory label) for an
// inference input
//
struct TestRecordedInput {
TestRecordedInput(int32_t data_in, size_t size_in)
: shared_memory_label(""), data(data_in), size(size_in)
{
}
TestRecordedInput(std::string label_in, size_t size_in)
: shared_memory_label(label_in), data(0), size(size_in)
{
}
std::string shared_memory_label;
int32_t data;
size_t size;
};
/// Mock class of an InferInput
///
class MockInferInput : public InferInput {
public:
MockInferInput(
const BackendKind kind, const std::string& name,
const std::vector<int64_t>& dims, const std::string& datatype)
: InferInput(kind, name, datatype), dims_(dims)
{
}
const std::vector<int64_t>& Shape() const override { return dims_; }
Error Reset() override
{
recorded_inputs_.clear();
return Error::Success;
}
Error AppendRaw(const uint8_t* input, size_t input_byte_size) override
{
if (input) {
int32_t val = *reinterpret_cast<const int32_t*>(input);
recorded_inputs_.push_back(TestRecordedInput(val, input_byte_size));
}
++append_raw_calls_;
return Error::Success;
}
  Error SetSharedMemory(
      const std::string& name, size_t byte_size, size_t offset = 0) override
{
recorded_inputs_.push_back(TestRecordedInput(name, byte_size));
++set_shared_memory_calls_;
return Error::Success;
}
const std::vector<int64_t> dims_{};
std::vector<TestRecordedInput> recorded_inputs_{};
std::atomic<size_t> append_raw_calls_{0};
std::atomic<size_t> set_shared_memory_calls_{0};
};
/// Mock class of an InferResult
///
class MockInferResult : public InferResult {
public:
MockInferResult(const InferOptions& options) : req_id_(options.request_id_) {}
Error Id(std::string* id) const override
{
*id = req_id_;
return Error::Success;
}
Error RequestStatus() const override { return Error::Success; }
Error RawData(
const std::string& output_name, const uint8_t** buf,
size_t* byte_size) const override
{
return Error::Success;
}
Error IsFinalResponse(bool* is_final_response) const override
{
if (is_final_response == nullptr) {
return Error("is_final_response cannot be nullptr");
}
*is_final_response = true;
return Error::Success;
}
Error IsNullResponse(bool* is_null_response) const override
{
if (is_null_response == nullptr) {
return Error("is_null_response cannot be nullptr");
}
*is_null_response = false;
return Error::Success;
}
private:
std::string req_id_;
};
/// Class to track statistics of MockClientBackend
///
class MockClientStats {
public:
enum class ReqType { SYNC, ASYNC, ASYNC_STREAM };
struct SeqStatus {
// Set of all unique sequence IDs observed in requests
//
std::set<uint64_t> used_seq_ids;
// Map of all "live" sequence IDs (sequences that have started and not
// ended) to their current length (how many requests have been sent to that
// sequence ID since it started)
//
std::map<uint64_t, uint32_t> live_seq_ids_to_length;
// Map of sequence ID to how many requests have been received for it.
//
std::map<uint64_t, uint32_t> seq_ids_to_count;
// Map of sequence IDs to how many are "inflight" for that sequence ID
// (inflight means the request has been received, response has not been
// returned)
//
std::map<uint64_t, uint32_t> seq_ids_to_inflight_count;
// Maximum observed number of live sequences (sequences that have started
// and not ended)
//
uint32_t max_live_seq_count = 0;
// Maximum observed number of inflight requests for a sequence
//
uint32_t max_inflight_seq_count = 0;
std::vector<uint64_t> seq_lengths;
bool IsSeqLive(uint64_t seq_id)
{
return (
live_seq_ids_to_length.find(seq_id) != live_seq_ids_to_length.end());
}
void HandleSeqStart(uint64_t seq_id)
{
used_seq_ids.insert(seq_id);
live_seq_ids_to_length[seq_id] = 0;
if (live_seq_ids_to_length.size() > max_live_seq_count) {
max_live_seq_count = live_seq_ids_to_length.size();
}
}
void HandleSeqEnd(uint64_t seq_id)
{
uint32_t len = live_seq_ids_to_length[seq_id];
seq_lengths.push_back(len);
auto it = live_seq_ids_to_length.find(seq_id);
live_seq_ids_to_length.erase(it);
}
void HandleSeqRequest(uint64_t seq_id)
{
live_seq_ids_to_length[seq_id]++;
if (seq_ids_to_count.find(seq_id) == seq_ids_to_count.end()) {
seq_ids_to_count[seq_id] = 1;
} else {
seq_ids_to_count[seq_id]++;
}
if (seq_ids_to_inflight_count.find(seq_id) ==
seq_ids_to_inflight_count.end()) {
seq_ids_to_inflight_count[seq_id] = 1;
} else {
seq_ids_to_inflight_count[seq_id]++;
}
if (seq_ids_to_inflight_count[seq_id] > max_inflight_seq_count) {
max_inflight_seq_count = seq_ids_to_inflight_count[seq_id];
}
}
void Reset()
{
// Note that live_seq_ids_to_length is explicitly not reset here.
// This is because we always want to maintain the true status of
// live sequences
used_seq_ids.clear();
max_live_seq_count = 0;
seq_lengths.clear();
seq_ids_to_count.clear();
}
};
std::atomic<size_t> num_infer_calls{0};
std::atomic<size_t> num_async_infer_calls{0};
std::atomic<size_t> num_async_stream_infer_calls{0};
std::atomic<size_t> num_start_stream_calls{0};
std::atomic<size_t> num_active_infer_calls{0};
std::atomic<size_t> num_append_raw_calls{0};
std::atomic<size_t> num_set_shared_memory_calls{0};
// Struct tracking shared memory method calls
//
struct SharedMemoryStats {
std::atomic<size_t> num_unregister_all_shared_memory_calls{0};
std::atomic<size_t> num_register_system_shared_memory_calls{0};
std::atomic<size_t> num_register_cuda_shared_memory_calls{0};
std::atomic<size_t> num_register_cuda_memory_calls{0};
std::atomic<size_t> num_register_system_memory_calls{0};
std::atomic<size_t> num_create_shared_memory_region_calls{0};
std::atomic<size_t> num_map_shared_memory_calls{0};
std::atomic<size_t> num_close_shared_memory_calls{0};
std::atomic<size_t> num_unlink_shared_memory_region_calls{0};
std::atomic<size_t> num_unmap_shared_memory_calls{0};
// bool operator==(const SharedMemoryStats& lhs, const SharedMemoryStats&
// rhs)
bool operator==(const SharedMemoryStats& rhs) const
{
if (this->num_unregister_all_shared_memory_calls ==
rhs.num_unregister_all_shared_memory_calls &&
this->num_register_system_shared_memory_calls ==
rhs.num_register_system_shared_memory_calls &&
this->num_register_cuda_shared_memory_calls ==
rhs.num_register_cuda_shared_memory_calls &&
this->num_register_cuda_memory_calls ==
rhs.num_register_cuda_memory_calls &&
this->num_register_system_memory_calls ==
rhs.num_register_system_memory_calls &&
this->num_create_shared_memory_region_calls ==
rhs.num_create_shared_memory_region_calls &&
this->num_map_shared_memory_calls ==
rhs.num_map_shared_memory_calls &&
this->num_close_shared_memory_calls ==
rhs.num_close_shared_memory_calls &&
this->num_unlink_shared_memory_region_calls ==
rhs.num_unlink_shared_memory_region_calls &&
this->num_unmap_shared_memory_calls ==
rhs.num_unmap_shared_memory_calls) {
return true;
}
return false;
}
};
/// Determines how long the backend will delay before sending a "response".
/// If a single value vector is passed in, all responses will take that long.
/// If a list of values is passed in, then the mock backend will loop through
/// the values (and loop back to the start when it hits the end of the vector)
///
void SetDelays(std::vector<size_t> times)
{
response_delays_.clear();
for (size_t t : times) {
response_delays_.push_back(std::chrono::milliseconds{t});
}
}
/// Determines the return status of requests.
/// If a single value vector is passed in, all responses will return that
/// status. If a list of values is passed in, then the mock backend will loop
/// through the values (and loop back to the start when it hits the end of the
/// vector)
///
void SetReturnStatuses(std::vector<bool> statuses)
{
response_statuses_.clear();
for (bool success : statuses) {
if (success) {
response_statuses_.push_back(Error::Success);
} else {
response_statuses_.push_back(Error("Injected test error"));
}
}
}
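
  // Example (hypothetical test usage): SetDelays({5, 20}) makes the mock
  // alternate between 5 ms and 20 ms response delays, and
  // SetReturnStatuses({true, true, false}) makes every third response fail
  // with an injected error; both sequences wrap around once exhausted.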
std::chrono::milliseconds GetNextDelay()
{
std::lock_guard<std::mutex> lock(mtx_);
auto val = response_delays_[response_delays_index_];
response_delays_index_++;
if (response_delays_index_ == response_delays_.size()) {
response_delays_index_ = 0;
}
return val;
}
Error GetNextReturnStatus()
{
std::lock_guard<std::mutex> lock(mtx_);
auto val = response_statuses_[response_statuses_index_];
response_statuses_index_++;
if (response_statuses_index_ == response_statuses_.size()) {
response_statuses_index_ = 0;
}
return val;
}
bool start_stream_enable_stats_value{false};
std::vector<std::chrono::time_point<std::chrono::system_clock>>
request_timestamps;
SeqStatus sequence_status;
SharedMemoryStats memory_stats;
  // Each entry in the outer vector holds all recorded inputs for one inference
  // request. If a request carries multiple inputs (due to batching and/or a
  // model with multiple inputs), they all land in the same inner vector.
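  // Illustrative shape (hypothetical values): two requests, each carrying
  // inputs INPUT0 and INPUT1, would be recorded as
  //   { {INPUT0_req0, INPUT1_req0}, {INPUT0_req1, INPUT1_req1} }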
std::vector<std::vector<TestRecordedInput>> recorded_inputs{};
void CaptureRequest(
ReqType type, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
{
num_active_infer_calls++;
std::lock_guard<std::mutex> lock(mtx_);
auto time = std::chrono::system_clock::now();
request_timestamps.push_back(time);
// Group all values across all inputs together into a single vector, and
// then record it
std::vector<TestRecordedInput> request_inputs;
for (const auto& input : inputs) {
      // Use a distinct name to avoid shadowing the `recorded_inputs` member.
      const auto& input_values =
          static_cast<const MockInferInput*>(input)->recorded_inputs_;
      request_inputs.insert(
          request_inputs.end(), input_values.begin(), input_values.end());
}
recorded_inputs.push_back(request_inputs);
UpdateCallCount(type);
UpdateSeqStatus(options);
AccumulateInferInputCalls(inputs);
}
void CaptureRequestEnd(const InferOptions& options)
{
num_active_infer_calls--;
if (options.sequence_id_ != 0) {
sequence_status.seq_ids_to_inflight_count[options.sequence_id_]--;
}
}
void CaptureStreamStart()
{
std::lock_guard<std::mutex> lock(mtx_);
num_start_stream_calls++;
}
void Reset()
{
std::lock_guard<std::mutex> lock(mtx_);
num_infer_calls = 0;
num_async_infer_calls = 0;
num_async_stream_infer_calls = 0;
num_start_stream_calls = 0;
request_timestamps.clear();
sequence_status.Reset();
}
private:
std::vector<std::chrono::milliseconds> response_delays_{
std::chrono::milliseconds{0}};
std::vector<Error> response_statuses_{Error::Success};
std::atomic<size_t> response_delays_index_{0};
std::atomic<size_t> response_statuses_index_{0};
std::mutex mtx_;
void UpdateCallCount(ReqType type)
{
if (type == ReqType::SYNC) {
num_infer_calls++;
} else if (type == ReqType::ASYNC) {
num_async_infer_calls++;
} else {
num_async_stream_infer_calls++;
}
}
void UpdateSeqStatus(const InferOptions& options)
{
// Seq ID of 0 is reserved for "not a sequence"
//
if (options.sequence_id_ != 0) {
// If a sequence ID is not live, it must be starting
if (!sequence_status.IsSeqLive(options.sequence_id_)) {
REQUIRE(options.sequence_start_ == true);
}
// If a new sequence is starting, that sequence ID must not already be
// live
if (options.sequence_start_ == true) {
REQUIRE(sequence_status.IsSeqLive(options.sequence_id_) == false);
sequence_status.HandleSeqStart(options.sequence_id_);
}
sequence_status.HandleSeqRequest(options.sequence_id_);
// If a sequence is ending, it must be live
if (options.sequence_end_) {
REQUIRE(sequence_status.IsSeqLive(options.sequence_id_) == true);
sequence_status.HandleSeqEnd(options.sequence_id_);
}
}
}
void AccumulateInferInputCalls(const std::vector<InferInput*>& inputs)
{
for (const auto& input : inputs) {
const MockInferInput* mock_input =
static_cast<const MockInferInput*>(input);
num_append_raw_calls += mock_input->append_raw_calls_;
num_set_shared_memory_calls += mock_input->set_shared_memory_calls_;
}
}
};
/// Mock implementation of ClientBackend interface
///
class NaggyMockClientBackend : public ClientBackend {
public:
NaggyMockClientBackend(std::shared_ptr<MockClientStats> stats) : stats_(stats)
{
ON_CALL(*this, AsyncStreamInfer(testing::_, testing::_, testing::_))
.WillByDefault(
[this](
const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs)
-> Error {
stats_->CaptureRequest(
MockClientStats::ReqType::ASYNC_STREAM, options, inputs,
outputs);
LaunchAsyncMockRequest(options, stream_callback_);
return stats_->GetNextReturnStatus();
});
}
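  // Because AsyncStreamInfer is a gMock method, tests can still replace the
  // default action above, e.g. (illustrative, with a hypothetical
  // `mock_backend` instance):
  //   EXPECT_CALL(mock_backend, AsyncStreamInfer(testing::_, testing::_, testing::_))
  //       .WillOnce(testing::Return(Error("stream failure")));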
MOCK_METHOD(
Error, ModelConfig,
(rapidjson::Document*, const std::string&, const std::string&),
(override));
MOCK_METHOD(
Error, AsyncStreamInfer,
(const InferOptions&, const std::vector<InferInput*>&,
const std::vector<const InferRequestedOutput*>&),
(override));
Error Infer(
InferResult** result, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override
{
stats_->CaptureRequest(
MockClientStats::ReqType::SYNC, options, inputs, outputs);
std::this_thread::sleep_for(stats_->GetNextDelay());
local_completed_req_count_++;
stats_->CaptureRequestEnd(options);
return stats_->GetNextReturnStatus();
}
Error AsyncInfer(
OnCompleteFn callback, const InferOptions& options,
const std::vector<InferInput*>& inputs,
const std::vector<const InferRequestedOutput*>& outputs) override
{
stats_->CaptureRequest(
MockClientStats::ReqType::ASYNC, options, inputs, outputs);
LaunchAsyncMockRequest(options, callback);
return stats_->GetNextReturnStatus();
}
  Error StartStream(OnCompleteFn callback, bool enable_stats) override
{
stats_->CaptureStreamStart();
stats_->start_stream_enable_stats_value = enable_stats;
stream_callback_ = callback;
return stats_->GetNextReturnStatus();
}
Error ClientInferStat(InferStat* infer_stat) override
{
infer_stat->completed_request_count = local_completed_req_count_;
return Error::Success;
}
Error UnregisterAllSharedMemory() override
{
stats_->memory_stats.num_unregister_all_shared_memory_calls++;
return Error::Success;
}
Error RegisterSystemSharedMemory(
const std::string& name, const std::string& key,
const size_t byte_size) override
{
stats_->memory_stats.num_register_system_shared_memory_calls++;
return Error::Success;
}
Error RegisterCudaSharedMemory(
const std::string& name, const cudaIpcMemHandle_t& handle,
const size_t byte_size) override
{
stats_->memory_stats.num_register_cuda_shared_memory_calls++;
return Error::Success;
}
Error RegisterCudaMemory(
const std::string& name, void* handle, const size_t byte_size) override
{
stats_->memory_stats.num_register_cuda_memory_calls++;
return Error::Success;
}
Error RegisterSystemMemory(
const std::string& name, void* memory_ptr,
const size_t byte_size) override
{
stats_->memory_stats.num_register_system_memory_calls++;
return Error::Success;
}
Error CreateSharedMemoryRegion(
std::string shm_key, size_t byte_size, int* shm_fd) override
{
stats_->memory_stats.num_create_shared_memory_region_calls++;
return Error::Success;
}
Error MapSharedMemory(
int shm_fd, size_t offset, size_t byte_size, void** shm_addr) override
{
stats_->memory_stats.num_map_shared_memory_calls++;
return Error::Success;
}
Error CloseSharedMemory(int shm_fd) override
{
stats_->memory_stats.num_close_shared_memory_calls++;
return Error::Success;
}
Error UnlinkSharedMemoryRegion(std::string shm_key) override
{
stats_->memory_stats.num_unlink_shared_memory_region_calls++;
return Error::Success;
}
Error UnmapSharedMemory(void* shm_addr, size_t byte_size) override
{
stats_->memory_stats.num_unmap_shared_memory_calls++;
return Error::Success;
}
OnCompleteFn stream_callback_;
private:
void LaunchAsyncMockRequest(const InferOptions options, OnCompleteFn callback)
{
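    // Simulate the server completing the request asynchronously: sleep for the
    // configured delay on a detached thread, then hand the callback a
    // heap-allocated MockInferResult (the callback is assumed to take
    // ownership of the result).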
std::thread([this, options, callback]() {
std::this_thread::sleep_for(stats_->GetNextDelay());
local_completed_req_count_++;
InferResult* result = new MockInferResult(options);
callback(result);
stats_->CaptureRequestEnd(options);
}).detach();
}
// Total count of how many requests this client has handled and finished
size_t local_completed_req_count_ = 0;
std::shared_ptr<MockClientStats> stats_;
};
using MockClientBackend = testing::NiceMock<NaggyMockClientBackend>;
/// Mock factory that always creates a MockClientBackend instead
/// of a real backend
///
class MockClientBackendFactory : public ClientBackendFactory {
public:
MockClientBackendFactory(std::shared_ptr<MockClientStats> stats)
{
stats_ = stats;
}
Error CreateClientBackend(std::unique_ptr<ClientBackend>* backend) override
{
std::unique_ptr<MockClientBackend> mock_backend(
new MockClientBackend(stats_));
*backend = std::move(mock_backend);
return Error::Success;
}
private:
std::shared_ptr<MockClientStats> stats_;
};
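// Illustrative wiring (hypothetical test snippet, not part of this header),
// showing how the pieces above are typically combined so a test can inject
// the mock backend and later inspect the captured statistics:
//
//   auto stats = std::make_shared<MockClientStats>();
//   stats->SetDelays({2});
//   MockClientBackendFactory factory(stats);
//   std::unique_ptr<ClientBackend> backend;
//   factory.CreateClientBackend(&backend);
//   // ... exercise `backend` through the code under test ...
//   // afterwards, stats->request_timestamps holds one entry per request.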
}}} // namespace triton::perfanalyzer::clientbackend
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
cmake_minimum_required (VERSION 3.18)
include(FetchContent)
FetchContent_Declare(tensorflow-serving-repo
  PREFIX tensorflow-serving-repo
)
FetchContent_GetProperties(tensorflow-serving-repo)
if(NOT tensorflow-serving-repo_POPULATED)
FetchContent_Populate(tensorflow-serving-repo
GIT_REPOSITORY "https://github.com/tensorflow/serving.git"
GIT_TAG "2.3.0"
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/tensorflow-serving-repo/src/tensorflow_serving"
)
endif()
FetchContent_Declare(tensorflow-repo
PREFIX tensorflow-repo
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/tensorflow-repo/src/tensorflow"
)
FetchContent_GetProperties(tensorflow-repo)
if(NOT tensorflow-repo_POPULATED)
FetchContent_Populate(tensorflow-repo
GIT_REPOSITORY "https://github.com/tensorflow/tensorflow.git"
GIT_TAG "v2.3.0"
SOURCE_DIR "${CMAKE_CURRENT_BINARY_DIR}/tensorflow-repo/src/tensorflow"
)
endif()
set(TENSORFLOW_PATH ${CMAKE_CURRENT_BINARY_DIR}/tensorflow-repo/src/tensorflow)
set(TFSERVE_PATH ${CMAKE_CURRENT_BINARY_DIR}/tensorflow-serving-repo/src/tensorflow_serving)
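# The two checkouts above are used only as a source of .proto files; neither
# TensorFlow nor TensorFlow Serving itself is built here.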
# Copy the repos to a proto staging area.
file(MAKE_DIRECTORY ${CMAKE_BINARY_DIR}/protos)
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory ${TENSORFLOW_PATH}/tensorflow
${CMAKE_BINARY_DIR}/protos/tensorflow)
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_directory ${TFSERVE_PATH}/tensorflow_serving
${CMAKE_BINARY_DIR}/protos/tensorflow_serving)
# Protobuf compiler dependency.
include(CompileProto.cmake)
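# CompileProto.cmake is expected to provide the compile_proto() helper used
# below; judging from the call sites, its signature is (assumption, not
# verified against that module):
#   compile_proto(<use_grpc_plugin 0|1> <proto_root> <out_dir> <srcs_var> <hdrs_var> <proto_files>...)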
# TensorFlow Serving protobuf sources to be compiled without the gRPC plugin.
file(GLOB_RECURSE TFSERVING_PROTOS ${CMAKE_BINARY_DIR}/protos/tensorflow_serving/*.proto)
file(GLOB TF_EXAMPLE_PROTOS ${CMAKE_BINARY_DIR}/protos/tensorflow/core/example/*.proto)
file(GLOB TF_FW_PROTOS ${CMAKE_BINARY_DIR}/protos/tensorflow/core/framework/*.proto)
file(GLOB TF_PROTOBUF_PROTOS ${CMAKE_BINARY_DIR}/protos/tensorflow/core/protobuf/*.proto)
# Workaround: exclude protos that would otherwise pull in unnecessary dependencies.
list(FILTER TF_PROTOBUF_PROTOS EXCLUDE REGEX "autotuning.proto$|conv_autotuning.proto$")
# Compiling CPP sources from proto files.
compile_proto(0 "${CMAKE_BINARY_DIR}/protos" "${CMAKE_CURRENT_BINARY_DIR}/compiled" PB_SOURCES PB_HEADERS
${TFSERVING_PROTOS} ${TF_EXAMPLE_PROTOS} ${TF_FW_PROTOS} ${TF_PROTOBUF_PROTOS})
# Compiling CPP sources with gRPC plugin.
compile_proto(1 "${CMAKE_BINARY_DIR}/protos" "${CMAKE_CURRENT_BINARY_DIR}/compiled" PB_GRPC_SOURCES PB_GRPC_HEADERS
${CMAKE_BINARY_DIR}/protos/tensorflow_serving/apis/prediction_service.proto)
set(
TFS_CLIENT_BACKEND_SRCS
tfserve_client_backend.cc
tfserve_infer_input.cc
tfserve_grpc_client.cc
${PB_SOURCES}
${PB_GRPC_SOURCES}
)
set(
TFS_CLIENT_BACKEND_HDRS
tfserve_client_backend.h
tfserve_infer_input.h
tfserve_grpc_client.h
${PB_HEADERS}
${PB_GRPC_HEADERS}
)
add_library(
tfs-client-backend-library EXCLUDE_FROM_ALL OBJECT
${TFS_CLIENT_BACKEND_SRCS}
${TFS_CLIENT_BACKEND_HDRS}
)
target_include_directories(tfs-client-backend-library PUBLIC ${CMAKE_CURRENT_BINARY_DIR}/compiled)
target_link_libraries(
tfs-client-backend-library
PUBLIC gRPC::grpc++
PUBLIC gRPC::grpc
PUBLIC protobuf::libprotobuf
PUBLIC grpcclient_static
)
if(${TRITON_ENABLE_GPU})
target_include_directories(tfs-client-backend-library PUBLIC ${CUDA_INCLUDE_DIRS})
target_link_libraries(tfs-client-backend-library PRIVATE ${CUDA_LIBRARIES})
endif() # TRITON_ENABLE_GPU