#!/usr/bin/env python3

# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
#  * Redistributions of source code must retain the above copyright
#    notice, this list of conditions and the following disclaimer.
#  * Redistributions in binary form must reproduce the above copyright
#    notice, this list of conditions and the following disclaimer in the
#    documentation and/or other materials provided with the distribution.
#  * Neither the name of NVIDIA CORPORATION nor the names of its
#    contributors may be used to endorse or promote products derived
#    from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import grpc
from tritonclient.grpc import service_pb2
from tritonclient.utils import *


def get_error_grpc(rpc_error):
    """Build an InferenceServerException describing a gRPC error.

    The exception is constructed and returned, not raised.

    Parameters
    ----------
    rpc_error : grpc.RpcError
        The gRPC error to translate. Its ``details()``, ``code()`` and
        ``debug_error_string()`` accessors are read.

    Returns
    -------
    InferenceServerException
        Exception whose message is the error details, whose status is
        the stringified status code, and whose debug details are the
        gRPC debug error string.

    """
    message = rpc_error.details()
    status_text = str(rpc_error.code())
    debug_text = rpc_error.debug_error_string()
    return InferenceServerException(
        msg=message,
        status=status_text,
        debug_details=debug_text,
    )


def raise_error_grpc(rpc_error):
    """Translate a gRPC error with ``get_error_grpc`` and raise the result.

    Parameters
    ----------
    rpc_error : grpc.RpcError
        The gRPC error to translate.

    Raises
    ------
    InferenceServerException
        Always raised.

    """
    translated = get_error_grpc(rpc_error)
    # "from None" suppresses the implicit exception context on the raised error.
    raise translated from None


def _get_inference_request(
    model_name,
    inputs,
    model_version,
    request_id,
    outputs,
    sequence_id,
    sequence_start,
    sequence_end,
    priority,
    timeout,
    parameters,
):
    """Assemble a ``service_pb2.ModelInferRequest`` from the given arguments.

    Parameters
    ----------
    model_name : str
        Stored in the request's ``model_name`` field.
    inputs : iterable
        Objects exposing ``_get_tensor()`` and ``_get_content()``. Each
        tensor is added to ``request.inputs``; content that is not None
        is added to ``request.raw_input_contents``.
    model_version : str
        Stored in the request's ``model_version`` field.
    request_id : str
        Stored in the request's ``id`` field unless it is the empty string.
    outputs : iterable or None
        Objects exposing ``_get_tensor()``; skipped entirely when None.
    sequence_id : int or str
        When it is neither 0 nor "", it is stored as the "sequence_id"
        parameter (string or int64 depending on its type), together with
        the "sequence_start" and "sequence_end" flags.
    sequence_start : bool
        Value of the "sequence_start" parameter (only set with a sequence id).
    sequence_end : bool
        Value of the "sequence_end" parameter (only set with a sequence id).
    priority : int
        Stored as the uint64 "priority" parameter unless it is 0.
    timeout : int or None
        Stored as the int64 "timeout" parameter unless it is None.
    parameters : dict or None
        Extra request parameters. Values must be str, bool or int.

    Returns
    -------
    service_pb2.ModelInferRequest
        The populated request message.

    Notes
    -----
    ``raise_error`` is invoked when ``parameters`` contains one of the
    reserved keys ("sequence_id", "sequence_start", "sequence_end",
    "priority", "binary_data_output") or a value of an unsupported type.

    """
    request = service_pb2.ModelInferRequest()
    request.model_name = model_name
    request.model_version = model_version
    if request_id != "":
        request.id = request_id

    for tensor_in in inputs:
        request.inputs.extend([tensor_in._get_tensor()])
        if tensor_in._get_content() is not None:
            request.raw_input_contents.extend([tensor_in._get_content()])

    if outputs is not None:
        request.outputs.extend([tensor_out._get_tensor() for tensor_out in outputs])

    request_params = request.parameters

    if sequence_id != 0 and sequence_id != "":
        sequence_entry = request_params["sequence_id"]
        if isinstance(sequence_id, str):
            sequence_entry.string_param = sequence_id
        else:
            sequence_entry.int64_param = sequence_id
        request_params["sequence_start"].bool_param = sequence_start
        request_params["sequence_end"].bool_param = sequence_end

    if priority != 0:
        request_params["priority"].uint64_param = priority

    if timeout is not None:
        request_params["timeout"].int64_param = timeout

    if parameters:
        reserved_keys = (
            "sequence_id",
            "sequence_start",
            "sequence_end",
            "priority",
            "binary_data_output",
        )
        for key, value in parameters.items():
            if key in reserved_keys:
                raise_error(
                    f'Parameter "{key}" is a reserved parameter and cannot be specified.'
                )
            elif isinstance(value, str):
                request_params[key].string_param = value
            # bool is tested before int because bool is a subclass of int.
            elif isinstance(value, bool):
                request_params[key].bool_param = value
            elif isinstance(value, int):
                request_params[key].int64_param = value
            else:
                raise_error(
                    f'The parameter datatype "{type(value)}" for key "{key}" is not supported.'
                )

    return request


def _grpc_compression_type(algorithm_str):
    """Map a compression algorithm name to a ``grpc.Compression`` value.

    Parameters
    ----------
    algorithm_str : str or None
        Algorithm name, matched case-insensitively against "deflate"
        and "gzip". None selects no compression.

    Returns
    -------
    grpc.Compression
        ``Deflate`` or ``Gzip`` for a recognised name, otherwise
        ``NoCompression``. An unrecognised name also prints a notice.

    """
    if algorithm_str is None:
        return grpc.Compression.NoCompression

    known_algorithms = {
        "deflate": grpc.Compression.Deflate,
        "gzip": grpc.Compression.Gzip,
    }
    normalized = algorithm_str.lower()
    if normalized in known_algorithms:
        return known_algorithms[normalized]

    print(
        "The provided client-side compression algorithm is not supported... "
        "using no compression"
    )
    return grpc.Compression.NoCompression