// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using Microsoft.ML.OnnxRuntime.Tensors;
using System.Buffers;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// Represents an Inference Session on an ONNX Model.
    /// This is an IDisposable class and it must be disposed of
    /// using either an explicit call to the Dispose() method or
    /// a using() block. If this is a member of another
    /// class, that class must also become IDisposable and it must
    /// dispose of the InferenceSession in its Dispose() method.
    /// </summary>
    public class InferenceSession : IDisposable
    {
        // NOTE(review): the generic type arguments on the collection types in this
        // copy of the file (Dictionary, IReadOnlyDictionary, IReadOnlyCollection, ...)
        // appear to have been stripped - restore them from the upstream source.

        /// <summary>
        /// A pointer to an underlying native instance of OrtSession
        /// </summary>
        private IntPtr _nativeHandle;

        /// <summary>
        /// Dictionary that represents input metadata
        /// </summary>
        private Dictionary _inputMetadata;

        /// <summary>
        /// Dictionary that represents output metadata
        /// </summary>
        private Dictionary _outputMetadata;

        /// <summary>
        /// Dictionary that represents overridableInitializers metadata
        /// </summary>
        private Dictionary _overridableInitializerMetadata;

        // Session options created by the constructors that do not take a
        // SessionOptions argument; owned by this instance.
        private SessionOptions _builtInSessionOptions = null;
        private RunOptions _builtInRunOptions = null;
        private ModelMetadata _modelMetadata = null;
        private bool _disposed = false;
        private ulong _profilingStartTimeNs = 0;

        #region Public API

        /// <summary>
        /// Constructs an InferenceSession from a model file
        /// </summary>
        /// <param name="modelPath">Model path</param>
        public InferenceSession(string modelPath)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            Init(modelPath, _builtInSessionOptions);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file and it will use
        /// the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(string modelPath, PrePackedWeightsContainer prepackedWeightsContainer)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            Init(modelPath, _builtInSessionOptions, prepackedWeightsContainer);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file, with some additional session options
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="options">Session options</param>
        public InferenceSession(string modelPath, SessionOptions options)
        {
            Init(modelPath, options);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model file, with some additional session options
        /// and it will use the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="modelPath">Model path</param>
        /// <param name="options">Session options</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(string modelPath, SessionOptions options,
            PrePackedWeightsContainer prepackedWeightsContainer)
        {
            Init(modelPath, options, prepackedWeightsContainer);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data in byte array
        /// </summary>
        /// <param name="model">Model as byte array</param>
        public InferenceSession(byte[] model)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            Init(model, _builtInSessionOptions);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data (in byte array) and it will use
        /// the provided pre-packed weights container to store and share pre-packed buffers
        /// of shared initializers across sessions if any.
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(byte[] model, PrePackedWeightsContainer prepackedWeightsContainer)
        {
            _builtInSessionOptions = new SessionOptions(); // need to be disposed
            Init(model, _builtInSessionOptions, prepackedWeightsContainer);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data in byte array, with some additional session options
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="options">Session options</param>
        public InferenceSession(byte[] model, SessionOptions options)
        {
            Init(model, options);
        }

        /// <summary>
        /// Constructs an InferenceSession from a model data (in byte array) with some additional
        /// session options and it will use the provided pre-packed weights container to store
        /// and share pre-packed buffers of shared initializers across sessions if any.
        /// </summary>
        /// <param name="model">Model as byte array</param>
        /// <param name="options">Session Options</param>
        /// <param name="prepackedWeightsContainer">Instance of PrepackedWeightsContainer.
        /// Lifetime of 'prepackedWeightsContainer' must be
        /// managed by the user and it must outlive any sessions reliant on it</param>
        public InferenceSession(byte[] model, SessionOptions options,
            PrePackedWeightsContainer prepackedWeightsContainer)
        {
            Init(model, options, prepackedWeightsContainer);
        }

        /// <summary>
        /// Metadata regarding the input nodes, keyed by input names
        /// </summary>
        public IReadOnlyDictionary InputMetadata
        {
            get
            {
                return _inputMetadata;
            }
        }

        /// <summary>
        /// Metadata regarding the output nodes, keyed by output names
        /// </summary>
        public IReadOnlyDictionary OutputMetadata
        {
            get
            {
                return _outputMetadata;
            }
        }

        /// <summary>
        /// Metadata regarding the overridable initializers, keyed by node names
        /// </summary>
        public IReadOnlyDictionary OverridableInitializerMetadata
        {
            get
            {
                return _overridableInitializerMetadata;
            }
        }

        /// <summary>
        /// Runs the loaded model for the given inputs, and fetches all the outputs.
        /// </summary>
        /// <param name="inputs">Specify a collection of <see cref="NamedOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs)
{
    // Request every output the session knows about, in key enumeration order.
    string[] allOutputNames = _outputMetadata.Keys.ToArray();
    return Run(inputs, allOutputNames);
}

/// <summary>
/// Runs the loaded model for the given inputs and fetches only the outputs named in
/// <paramref name="outputNames"/>. The session's built-in run options are used.
/// </summary>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Names of the outputs to fetch.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames)
    => Run(inputs, outputNames, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and fetches the outputs named in
/// <paramref name="outputNames"/>, using the supplied run options.
/// </summary>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Names of the outputs to fetch.</param>
/// <param name="options">Run options for this call.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(IReadOnlyCollection inputs, IReadOnlyCollection outputNames, RunOptions options)
{
    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        var inputNamePtrs = ConvertNamesToUtf8(inputs, value => value.Name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputs, disposables);
        var outputNamePtrs = ConvertNamesToUtf8(outputNames, name => name, disposables);

        var fetched = RunImpl(options, inputNamePtrs, inputValuePtrs, outputNamePtrs, disposables);
        return CreateDisposableResult(fetched, outputNames);
    }
}

/// <summary>
/// Runs the loaded model for the given inputs, and fetches all the outputs.
/// </summary>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues)
{
    // Request every output the session knows about, in key enumeration order.
    string[] allOutputNames = _outputMetadata.Keys.ToArray();
    return Run(inputNames, inputValues, allOutputNames, _builtInRunOptions);
}

/// <summary>
/// Runs the loaded model for the given inputs and fetches only the outputs named in
/// <paramref name="outputNames"/>. The session's built-in run options are used.
/// </summary>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Names of the outputs to fetch.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputNames)
    => Run(inputNames, inputValues, outputNames, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and fetches the outputs named in
/// <paramref name="outputNames"/>, using the supplied run options.
/// </summary>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Names of the outputs to fetch.</param>
/// <param name="options">Run options for this call.</param>
/// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
public IDisposableReadOnlyCollection Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputNames,
    RunOptions options)
{
    // Names and values are paired positionally, so the counts have to agree.
    if (inputNames.Count != inputValues.Count)
    {
        throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
    }

    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        var inputNamePtrs = ConvertNamesToUtf8(inputNames, name => name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputValues, input: true);
        var outputNamePtrs = ConvertNamesToUtf8(outputNames, name => name, disposables);

        var fetched = RunImpl(options, inputNamePtrs, inputValuePtrs, outputNamePtrs, disposables);
        return CreateDisposableResult(fetched, outputNames);
    }
}

/// <summary>
/// Runs the loaded model for the given inputs and writes the results into the
/// caller-supplied output values. The session's built-in run options are used.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to accept the fetched data.
/// </remarks>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Collection of output names. Its count must match that of <paramref name="outputValues"/>.</param>
/// <param name="outputValues">Collection of <see cref="FixedBufferOnnxValue"/> that receive the output values.</param>
public void Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputNames,
    IReadOnlyCollection outputValues)
    => Run(inputNames, inputValues, outputNames, outputValues, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to accept the fetched data.
/// </remarks>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Collection of output names. Its count must match that of <paramref name="outputValues"/>.</param>
/// <param name="outputValues">Collection of <see cref="FixedBufferOnnxValue"/> that receive the output values.</param>
/// <param name="options">Run options for this call.</param>
public void Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputNames,
    IReadOnlyCollection outputValues,
    RunOptions options)
{
    // Names and values are paired positionally, so the counts have to agree.
    if (inputNames.Count != inputValues.Count)
    {
        throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
    }

    if (outputNames.Count != outputValues.Count)
    {
        throw new ArgumentException($"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
    }

    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        // inputs
        var inputNamePtrs = ConvertNamesToUtf8(inputNames, name => name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputValues, input: true);

        // outputs
        var outputNamePtrs = ConvertNamesToUtf8(outputNames, name => name, disposables);
        var outputValuePtrs = GetOrtValuesHandles(outputValues, input: false);

        var status = NativeMethods.OrtRun(
            _nativeHandle,
            options.Handle,
            inputNamePtrs,
            inputValuePtrs,
            (UIntPtr)inputNames.Count,
            outputNamePtrs,
            (UIntPtr)outputNames.Count,
            outputValuePtrs /* pointers to Pre-allocated OrtValue instances */
            );
        NativeApiStatus.VerifySuccess(status);
    }
}

/// <summary>
/// Runs the loaded model for the given inputs and writes the results into the
/// caller-supplied outputs. The session's built-in run options are used.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputs">Collection of <see cref="NamedOnnxValue"/> that receive the output values.</param>
public void Run(
    IReadOnlyCollection inputs,
    IReadOnlyCollection outputs)
    => Run(inputs, outputs, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputs">Collection of <see cref="NamedOnnxValue"/> that receive the output values.</param>
/// <param name="options">Run options for this call.</param>
public void Run(
    IReadOnlyCollection inputs,
    IReadOnlyCollection outputs,
    RunOptions options)
{
    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        var inputNamePtrs = ConvertNamesToUtf8(inputs, value => value.Name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputs, disposables);

        var outputNamePtrs = ConvertNamesToUtf8(outputs, value => value.Name, disposables);
        var outputValuePtrs = GetOrtValuesHandles(outputs, disposables);

        var status = NativeMethods.OrtRun(
            _nativeHandle,
            options.Handle,
            inputNamePtrs,
            inputValuePtrs,
            (UIntPtr)inputs.Count,
            outputNamePtrs,
            (UIntPtr)outputs.Count,
            outputValuePtrs /* pointers to Pre-allocated OrtValue instances */
            );
        NativeApiStatus.VerifySuccess(status);
    }
}

/// <summary>
/// Runs the loaded model for the given inputs and writes the results into the
/// caller-supplied output values. The session's built-in run options are used.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Collection of output names. Its count must match that of <paramref name="outputValues"/>.</param>
/// <param name="outputValues">Collection of <see cref="FixedBufferOnnxValue"/> that receive the output values.</param>
public void Run(
    IReadOnlyCollection inputs,
    IReadOnlyCollection outputNames,
    IReadOnlyCollection outputValues)
    => Run(inputs, outputNames, outputValues, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputs">Collection of <see cref="NamedOnnxValue"/> holding the input values.</param>
/// <param name="outputNames">Collection of output names. Its count must match that of <paramref name="outputValues"/>.</param>
/// <param name="outputValues">Collection of <see cref="FixedBufferOnnxValue"/> that receive the output values.</param>
/// <param name="options">Run options for this call.</param>
public void Run(
    IReadOnlyCollection inputs,
    IReadOnlyCollection outputNames,
    IReadOnlyCollection outputValues,
    RunOptions options)
{
    // Output names and values are paired positionally, so the counts have to agree.
    if (outputNames.Count != outputValues.Count)
    {
        throw new ArgumentException($"Length of {nameof(outputNames)} ({outputNames.Count}) must match that of {nameof(outputValues)} ({outputValues.Count}).");
    }

    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        // inputs
        var inputNamePtrs = ConvertNamesToUtf8(inputs, value => value.Name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputs, disposables);

        // outputs
        var outputNamePtrs = ConvertNamesToUtf8(outputNames, name => name, disposables);
        var outputValuePtrs = GetOrtValuesHandles(outputValues, input: false);

        var status = NativeMethods.OrtRun(
            _nativeHandle,
            options.Handle,
            inputNamePtrs,
            inputValuePtrs,
            (UIntPtr)inputs.Count,
            outputNamePtrs,
            (UIntPtr)outputNames.Count,
            outputValuePtrs /* pointers to Pre-allocated OrtValue instances */
            );
        NativeApiStatus.VerifySuccess(status);
    }
}

/// <summary>
/// Runs the loaded model for the given inputs and writes the results into the
/// caller-supplied outputs. The session's built-in run options are used.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputs">Collection of <see cref="NamedOnnxValue"/> that receive the output values.</param>
public void Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputs)
    => Run(inputNames, inputValues, outputs, _builtInRunOptions);

/// <summary>
/// Runs the loaded model for the given inputs and outputs. Uses the given RunOptions for this run.
/// </summary>
/// <remarks>
/// Outputs need to be created with correct type and dimension to receive the fetched data.
/// </remarks>
/// <param name="inputNames">Collection of input names. Its count must match that of <paramref name="inputValues"/>.</param>
/// <param name="inputValues">Collection of <see cref="FixedBufferOnnxValue"/> holding the input values.</param>
/// <param name="outputs">Collection of <see cref="NamedOnnxValue"/> that receive the output values.</param>
/// <param name="options">Run options for this call.</param>
public void Run(
    IReadOnlyCollection inputNames,
    IReadOnlyCollection inputValues,
    IReadOnlyCollection outputs,
    RunOptions options)
{
    // Input names and values are paired positionally, so the counts have to agree.
    if (inputNames.Count != inputValues.Count)
    {
        throw new ArgumentException($"Length of {nameof(inputNames)} ({inputNames.Count}) must match that of {nameof(inputValues)} ({inputValues.Count}).");
    }

    // Everything registered with this list is disposed when the using block exits.
    using (var disposables = new DisposableList())
    {
        // inputs
        var inputNamePtrs = ConvertNamesToUtf8(inputNames, name => name, disposables);
        var inputValuePtrs = GetOrtValuesHandles(inputValues, input: true);

        // outputs
        var outputNamePtrs = ConvertNamesToUtf8(outputs, value => value.Name, disposables);
        var outputValuePtrs = GetOrtValuesHandles(outputs, disposables);

        var status = NativeMethods.OrtRun(
            _nativeHandle,
            options.Handle,
            inputNamePtrs,
            inputValuePtrs,
            (UIntPtr)inputNames.Count,
            outputNamePtrs,
            (UIntPtr)outputs.Count,
            outputValuePtrs /* pointers to Pre-allocated OrtValue instances */
            );
        NativeApiStatus.VerifySuccess(status);
    }
}

/// <summary>
/// Create OrtIoBinding instance to bind pre-allocated buffers
/// to input/output
/// </summary>
/// <returns>A new instance of OrtIoBinding</returns>
public OrtIoBinding CreateIoBinding() => new OrtIoBinding(this);

///
/// This method runs inference on the OrtIoBinding instance
/// The method does not return anything. This is a lightweight version of
/// RunWithBindingAndNames(). When you bind pre-allocated buffers to the output values
/// you may not want to fetch the outputs since you already have access to them so you can spare
/// the expense of fetching them and pairing with names.
/// You can still fetch the outputs by calling OrtIoBinding.GetOutputValues()
///
/// <param name="runOptions">runOptions</param>
/// <param name="ioBinding">ioBinding instance to use</param>
public void RunWithBinding(RunOptions runOptions, OrtIoBinding ioBinding)
{
    NativeApiStatus.VerifySuccess(NativeMethods.OrtRunWithBinding(Handle, runOptions.Handle, ioBinding.Handle));
}

/// <summary>
/// This method returns a collection of DisposableNamedOnnxValue as in other interfaces.
/// Query names from the OrtIoBinding object and pair them with the array of OrtValues returned
/// from OrtIoBinding.GetOutputValues()
/// </summary>
/// <param name="runOptions">RunOptions</param>
/// <param name="ioBinding">OrtIoBinding instance with bindings</param>
/// <param name="names">optional parameter. If you already know the names of the outputs you can save a native
/// call to retrieve output names. They will be paired with the returned OrtValues and combined into DisposableNamedOnnxValues.
/// Otherwise, the method will retrieve output names from the OrtIoBinding instance.
/// It is an error if you supply a different number of names than the returned outputs</param>
/// <returns>A disposable collection of DisposableNamedOnnxValue that encapsulate output OrtValues</returns>
/// <exception cref="OnnxRuntimeException">The number of names does not match the number of output values.</exception>
public IDisposableReadOnlyCollection RunWithBindingAndNames(RunOptions runOptions, OrtIoBinding ioBinding, string[] names = null)
{
    NativeApiStatus.VerifySuccess(NativeMethods.OrtRunWithBinding(Handle, runOptions.Handle, ioBinding.Handle));
    using (var ortValues = ioBinding.GetOutputValues())
    {
        // Fall back to the names recorded in the binding when the caller supplied none.
        string[] outputNames = names ?? ioBinding.GetOutputNames();

        if (outputNames.Length != ortValues.Count)
        {
            // Report the length of the names actually used: 'names' itself may be null here,
            // and dereferencing it would replace this error with a NullReferenceException.
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Number of specified names: " + outputNames.Length + " does not match the output number: " +
                ortValues.Count);
        }

        var result = new DisposableList(outputNames.Length);
        try
        {
            for (int i = 0; i < outputNames.Length; ++i)
            {
                var ortValue = ortValues.ElementAt(i);
                result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames[i], ortValue));
            }
        }
        catch (Exception)
        {
            // Do not leak the values already wrapped if a later one fails.
            result.Dispose();
            throw;
        }
        return result;
    }
}

///
/// Ends
/// profiling for the session.
///
/// <returns>Returns the profile file name.</returns>
public string EndProfiling()
{
    IntPtr nameHandle = IntPtr.Zero;
    var allocator = OrtAllocator.DefaultInstance;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionEndProfiling(_nativeHandle,
                                                                      allocator.Pointer,
                                                                      out nameHandle));

    // The allocation wrapper is disposed once the managed string has been created.
    using (var allocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
    {
        return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
    }
}

// Delegate for string extraction from an arbitrary input/output object
private delegate string NameExtractor(TInput input);

/// <summary>
/// Run helper. Converts each name to a zero terminated UTF-8 buffer, pins it and
/// returns the pinned addresses in enumeration order.
/// </summary>
/// <param name="inputs">objects whose names are converted to zero terminated utf8 and pinned</param>
/// <param name="extractor">delegate that obtains the name from an element of <paramref name="inputs"/></param>
/// <param name="cleanupList">list to add pinned memory to for later disposal</param>
/// <returns>array of pointers to the pinned names, one per element</returns>
private IntPtr[] ConvertNamesToUtf8(IReadOnlyCollection inputs, NameExtractor extractor, DisposableList cleanupList)
{
    var result = new IntPtr[inputs.Count];
    int index = 0;
    // Enumerate once instead of calling ElementAt(i) per index, which re-walks
    // the sequence for collections that are not indexable.
    foreach (var item in inputs)
    {
        var name = extractor(item);
        var utf8Name = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name);
        var pinnedHandle = new PinnedGCHandle(GCHandle.Alloc(utf8Name, GCHandleType.Pinned));
        // Register for disposal before publishing the pointer.
        cleanupList.Add(pinnedHandle);
        result[index++] = pinnedHandle.Pointer;
    }
    return result;
}

/// <summary>
/// This function obtains ortValues for NamedOnnxValue.
/// The problem with NamedOnnxValue is that it does not contain any Onnx (OrtValue)
/// so calling ToOrtValue creates a new instance of OrtValue that needs to be disposed.
/// The deriving object DisposableNamedValue actually contains and owns OrtValue and it returns
/// it.
/// </summary>
/// <param name="values">values to obtain OrtValue handles for</param>
/// <param name="cleanupList">receives the OrtValue (and memory handle, if any) of each value for later disposal</param>
/// <returns>array of native OrtValue handles, one per element</returns>
private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, DisposableList cleanupList)
{
    IntPtr[] result = new IntPtr[values.Count];
    int inputIndex = 0;
    foreach (var input in values)
    {
        MemoryHandle? memHandle;
        var ortValue = input.ToOrtValue(out memHandle);
        if (memHandle.HasValue)
        {
            cleanupList.Add(memHandle);
        }
        cleanupList.Add(ortValue);
        result[inputIndex++] = ortValue.Handle;
    }
    return result;
}

/// <summary>
/// Collects the native OrtValue handles of fixed buffer values.
/// </summary>
/// <param name="values">values to obtain handles from</param>
/// <param name="input">true when the values are inputs; string typed values are rejected when false</param>
/// <returns>array of native OrtValue handles, one per element</returns>
/// <exception cref="NotSupportedException">A string typed value is used as an output.</exception>
private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection values, bool input)
{
    var valuesArray = new IntPtr[values.Count];
    int index = 0;
    foreach (var v in values)
    {
        if (!input && v.ElementType == Tensors.TensorElementType.String)
        {
            throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
        }
        valuesArray[index++] = v.Value.Handle;
    }
    return valuesArray;
}

/// <summary>
/// Invokes the native run and wraps every returned output pointer in an OrtValue.
/// </summary>
/// <param name="options">run options</param>
/// <param name="inputNames">pointers to the utf8 input names</param>
/// <param name="inputValues">native handles of the input values</param>
/// <param name="outputNames">pointers to the utf8 output names</param>
/// <param name="cleanupList">the returned list is also registered here for disposal</param>
/// <returns>list of the output OrtValues, in the order of <paramref name="outputNames"/></returns>
private DisposableList RunImpl(RunOptions options, IntPtr[] inputNames, IntPtr[] inputValues, IntPtr[] outputNames,
   DisposableList cleanupList)
{
    var ortValues = new DisposableList(outputNames.Length);
    cleanupList.Add(ortValues);

    IntPtr[] outputValuesArray = new IntPtr[outputNames.Length];
    NativeApiStatus.VerifySuccess(NativeMethods.OrtRun(
                                        _nativeHandle,
                                        options.Handle,
                                        inputNames,
                                        inputValues,
                                        (UIntPtr)inputNames.Length,
                                        outputNames,
                                        (UIntPtr)outputNames.Length,
                                        outputValuesArray /* Empty array is passed in to receive output OrtValue pointers */
                                        ));

    foreach (var v in outputValuesArray)
    {
        ortValues.Add(new OrtValue(v));
    }
    return ortValues;
}

/// <summary>
/// Pairs each OrtValue with the output name at the same position and wraps the pair
/// in a DisposableNamedOnnxValue.
/// </summary>
/// <param name="ortValues">output values</param>
/// <param name="outputNames">output names, in the same order as <paramref name="ortValues"/></param>
/// <returns>disposable collection of the named results</returns>
private IDisposableReadOnlyCollection CreateDisposableResult(List ortValues,
    IReadOnlyCollection outputNames)
{
    var result = new DisposableList(outputNames.Count);
    try
    {
        for (int i = 0; i < ortValues.Count; i++)
        {
            var ortValue = ortValues[i];
            result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(outputNames.ElementAt(i), ortValue));
        }
    }
    catch (Exception)
    {
        // Dispose the partially built result whatever the failure was, not only for
        // OnnxRuntimeException, so already wrapped values are not leaked.
        result.Dispose();
        throw;
    }
    return result;
}

/// <summary>
/// This property queries model metadata, constructs
/// an instance of ModelMetadata and caches it
/// </summary>
/// <returns>Instance of ModelMetadata</returns>
public ModelMetadata ModelMetadata
{
    get
    {
        if (_modelMetadata != null)
        {
            return _modelMetadata;
        }

        _modelMetadata = new ModelMetadata(this);
        return _modelMetadata;
} // closes the getter whose opening lines precede this span
}

/// <summary>
/// Return the nanoseconds of profiling's start time
/// On some platforms, this timer may not be as precise as nanoseconds
/// For instance, on Windows and MacOS, the precision will be ~100ns
/// </summary>
public ulong ProfilingStartTimeNs
{
    get
    {
        return _profilingStartTimeNs;
    }
}

#endregion

#region private methods

/// <summary>
/// Creates the native session from a model file and initializes this instance with it.
/// </summary>
/// <param name="modelPath">Model path</param>
/// <param name="options">Session options whose native handle is passed to session creation</param>
/// <param name="prepackedWeightsContainer">Optional container; when not null the session is created
/// with the pre-packed weights container variant of the native call</param>
private void Init(string modelPath, SessionOptions options,
                  PrePackedWeightsContainer prepackedWeightsContainer = null)
{
    var envHandle = OrtEnv.Handle;
    var session = IntPtr.Zero;

    if (prepackedWeightsContainer == null)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSession(envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
            options.Handle, out session));
    }
    else
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionWithPrepackedWeightsContainer(
            envHandle, NativeOnnxValueHelper.GetPlatformSerializedString(modelPath),
            options.Handle, prepackedWeightsContainer.Pointer, out session));
    }

    InitWithSessionHandle(session, options);
}

/// <summary>
/// Creates the native session from in-memory model data and initializes this instance with it.
/// </summary>
/// <param name="modelData">Model as byte array</param>
/// <param name="options">Session options whose native handle is passed to session creation</param>
/// <param name="prepackedWeightsContainer">Optional container; when not null the session is created
/// with the pre-packed weights container variant of the native call</param>
private void Init(byte[] modelData, SessionOptions options,
                  PrePackedWeightsContainer prepackedWeightsContainer = null)
{
    var envHandle = OrtEnv.Handle;
    var session = IntPtr.Zero;

    if (prepackedWeightsContainer == null)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArray(envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, out session));
    }
    else
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionFromArrayWithPrepackedWeightsContainer(
            envHandle, modelData, (UIntPtr)modelData.Length, options.Handle, prepackedWeightsContainer.Pointer,
            out session));
    }

    InitWithSessionHandle(session, options);
}

/// <summary>
/// Initializes the session object with a native session handle
/// </summary>
/// <param name="session">Value of a native session object</param>
/// <param name="options">Session options (not read by this method)</param>
private void InitWithSessionHandle(IntPtr session, SessionOptions options)
{
    _nativeHandle = session;
    try
    {
        // Initialize input/output metadata
        _inputMetadata = new Dictionary();
        _outputMetadata = new Dictionary();
        _overridableInitializerMetadata = new Dictionary();

        // get input count
        UIntPtr inputCount = UIntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputCount(_nativeHandle, out inputCount));

        // get all the input names and metadata
        for (ulong i = 0; i < (ulong)inputCount; i++)
        {
            var iname = GetInputName(i);
            _inputMetadata[iname] = GetInputMetadata(i);
        }

        // get output count
        UIntPtr outputCount = UIntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputCount(_nativeHandle, out outputCount));

        // get all the output names and metadata
        for (ulong i = 0; i < (ulong)outputCount; i++)
        {
            _outputMetadata[GetOutputName(i)] = GetOutputMetadata(i);
        }

        // get overridable initializer count
        UIntPtr initilaizerCount = UIntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerCount(_nativeHandle, out initilaizerCount));

        // get all the overridable initializer names and metadata
        for (ulong i = 0; i < (ulong)initilaizerCount; i++)
        {
            _overridableInitializerMetadata[GetOverridableInitializerName(i)] = GetOverridableInitializerMetadata(i);
        }

        // set profiling's start time
        UIntPtr startTime = UIntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetProfilingStartTimeNs(_nativeHandle,
                                                                                      out startTime));
        _profilingStartTimeNs = (ulong)startTime;
    }
    catch (OnnxRuntimeException)
    {
        // Release the native session right away when metadata retrieval fails,
        // and clear the handle so that it is not released a second time.
        // NOTE(review): only OnnxRuntimeException is handled here; for any other
        // exception type the handle stays set - confirm that is intended.
        if (_nativeHandle != IntPtr.Zero)
        {
            NativeMethods.OrtReleaseSession(_nativeHandle);
            _nativeHandle = IntPtr.Zero;
        }
        throw;
    }

    _builtInRunOptions = new RunOptions();  // create a default built-in run option, and avoid creating a new one every run() call
}

// Returns the name of the output at the given index, as reported by the native session.
private string GetOutputName(ulong index)
{
    var allocator = OrtAllocator.DefaultInstance;
    IntPtr nameHandle = IntPtr.Zero;
    string str = null;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputName(
                                   _nativeHandle,
                                   (UIntPtr)index,
                                   allocator.Pointer,
                                   out nameHandle));

    // The allocation wrapper is disposed once the managed string has been created.
    using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
    {
        str =
NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
    }

    return str;
}

// Returns the name of the input at the given index, as reported by the native session.
private string GetInputName(ulong index)
{
    string str = null;
    var allocator = OrtAllocator.DefaultInstance;
    IntPtr nameHandle = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputName(
                                   _nativeHandle,
                                   (UIntPtr)index,
                                   allocator.Pointer,
                                   out nameHandle));

    // The allocation wrapper is disposed once the managed string has been created.
    using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
    {
        str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
    }
    return str;
}

// Returns the name of the overridable initializer at the given index, as reported by the native session.
private string GetOverridableInitializerName(ulong index)
{
    string str = null;
    var allocator = OrtAllocator.DefaultInstance;
    IntPtr nameHandle = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerName(
                                    _nativeHandle,
                                    (UIntPtr)index,
                                    allocator.Pointer,
                                    out nameHandle));

    // The allocation wrapper is disposed once the managed string has been created.
    using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
    {
        str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
    }
    return str;
}

// Queries the type info of the input at the given index and converts it to NodeMetadata.
// The native type info is released before returning, also when the conversion throws.
private NodeMetadata GetInputMetadata(ulong index)
{
    IntPtr typeInfo = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetInputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
    try
    {
        return GetMetadataFromTypeInfo(typeInfo);
    }
    finally
    {
        NativeMethods.OrtReleaseTypeInfo(typeInfo);
    }
}

// Queries the type info of the output at the given index and converts it to NodeMetadata.
// The native type info is released before returning, also when the conversion throws.
private NodeMetadata GetOutputMetadata(ulong index)
{
    IntPtr typeInfo = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOutputTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
    try
    {
        return GetMetadataFromTypeInfo(typeInfo);
    }
    finally
    {
        NativeMethods.OrtReleaseTypeInfo(typeInfo);
    }
}

// Queries the type info of the overridable initializer at the given index and converts it
// to NodeMetadata. The native type info is released before returning, also when the conversion throws.
private NodeMetadata GetOverridableInitializerMetadata(ulong index)
{
    IntPtr typeInfo = IntPtr.Zero;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetOverridableInitializerTypeInfo(_nativeHandle, (UIntPtr)index, out typeInfo));
    try
    {
        return GetMetadataFromTypeInfo(typeInfo);
    }
    finally
    {
        NativeMethods.OrtReleaseTypeInfo(typeInfo);
    }
}

/// <summary>
/// Builds NodeMetadata from a native type info handle. For value types other than
/// tensor and sparse tensor, metadata with empty dimension arrays is returned.
/// Returns null when the type info cannot be cast to tensor info.
/// The caller keeps ownership of <paramref name="typeInfo"/>; it is not released here.
/// </summary>
/// <param name="typeInfo">native type info handle</param>
/// <returns>metadata describing value type, dimensions, symbolic dimensions and element type</returns>
internal static NodeMetadata GetMetadataFromTypeInfo(IntPtr typeInfo)
{
    OnnxValueType valueType;
    {
        IntPtr valType;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetOnnxTypeFromTypeInfo(typeInfo, out valType));
        valueType = (OnnxValueType)valType;
    }
    if (valueType != OnnxValueType.ONNX_TYPE_TENSOR && valueType != OnnxValueType.ONNX_TYPE_SPARSETENSOR)
    {
        return new NodeMetadata(valueType, new int[] { }, new string[] { }, typeof(NamedOnnxValue));
    }

    // This should not be released
    IntPtr tensorInfo;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtCastTypeInfoToTensorInfo(typeInfo, out tensorInfo)); //(IntPtr)(int)(uint)
    // Convert the newly introduced OrtTypeInfo* to the older OrtTypeAndShapeInfo*

    if (tensorInfo == IntPtr.Zero)
        return null;

    TensorElementType type;
    {
        IntPtr el_type;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorElementType(tensorInfo, out el_type));
        type = (TensorElementType)el_type;
    }
    Type dotnetType = null;
    int width = 0;
    if (!TensorElementTypeConverter.GetTypeAndWidth(type, out dotnetType, out width))
    {
        throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
            "Unable to query type information for data type: " + type.ToString());
    }

    UIntPtr numDimensions;
    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensionsCount(tensorInfo, out numDimensions));

    long[] dimensions = new long[(int)numDimensions];
    NativeApiStatus.VerifySuccess(NativeMethods.OrtGetDimensions(tensorInfo, dimensions, numDimensions));
    // Dimensions are narrowed from long to int for the public NodeMetadata shape.
    int[] intDimensions = new int[(int)numDimensions];
    for (var i = 0; i < (long)numDimensions; i++)
    {
        intDimensions[i] = (int)dimensions[i];
    }

    IntPtr[] dimensionNamePtrs = new IntPtr[(int)numDimensions];
    NativeApiStatus.VerifySuccess(
        NativeMethods.OrtGetSymbolicDimensions(tensorInfo, dimensionNamePtrs, numDimensions));

    string[] symbolicDimensions = new string[(int)numDimensions];
    for (var i = 0; i < (int)numDimensions; i++)
    {
        symbolicDimensions[i] = NativeOnnxValueHelper.StringFromNativeUtf8(dimensionNamePtrs[i]);
    }

    return new NodeMetadata(valueType, intDimensions,
symbolicDimensions, dotnetType);
}

/// <summary>
/// Other classes access
/// </summary>
internal IntPtr Handle
{
    get
    {
        return _nativeHandle;
    }
}

#endregion

#region IDisposable

/// <summary>
/// Finalizer. to cleanup session in case it runs
/// and the user forgets to Dispose() of the session
/// </summary>
~InferenceSession()
{
    Dispose(false);
}

/// <summary>
/// IDisposable implementation
/// </summary>
public void Dispose()
{
    Dispose(true);
    GC.SuppressFinalize(this);
}

/// <summary>
/// IDisposable implementation. Does nothing after the first call.
/// </summary>
/// <param name="disposing">true if invoked from Dispose() method</param>
protected virtual void Dispose(bool disposing)
{
    if (_disposed)
    {
        return;
    }

    if (disposing)
    {
        // cleanup managed resources
        if (_builtInSessionOptions != null)
        {
            _builtInSessionOptions.Dispose();
            _builtInSessionOptions = null;
        }

        if (_builtInRunOptions != null)
        {
            _builtInRunOptions.Dispose();
            _builtInRunOptions = null;
        }
    }

    // cleanup unmanaged resources
    if (_nativeHandle != IntPtr.Zero)
    {
        NativeMethods.OrtReleaseSession(_nativeHandle);
        _nativeHandle = IntPtr.Zero;
    }
    _disposed = true;
}

#endregion
} // class InferenceSession


/// <summary>
/// Resembles type and shape information of session-graph nodes, used for communicating the shape/type of input/output nodes
/// </summary>
public class NodeMetadata
{
    internal NodeMetadata(OnnxValueType onnxValueType, int[] dimensions, string[] symbolicDimensions, Type type)
    {
        OnnxValueType = onnxValueType;
        Dimensions = dimensions;
        SymbolicDimensions = symbolicDimensions;
        ElementType = type;
    }

    /// <summary>
    /// Type value of the node
    /// </summary>
    /// <value>A value of OnnxValueType enum</value>
    public OnnxValueType OnnxValueType { get; }

    /// <summary>
    /// Shape
    /// </summary>
    /// <value>Array of dimensions</value>
    public int[] Dimensions { get; }

    /// <summary>
    /// Symbolic dimensions
    /// </summary>
    /// <value>Array of symbolic dimensions if present.</value>
    public string[] SymbolicDimensions { get; }

    /// <summary>
    /// .NET type that corresponds to this Node.
    /// </summary>
    /// <value>System.Type</value>
    public System.Type ElementType { get; }

    /// <summary>
    /// Whether it is a Tensor
    /// </summary>
    /// <value>currently always returns true</value>
    public bool IsTensor
    {
        get
        {
            // NOTE(review): this ignores OnnxValueType, although instances are also
            // constructed for non-tensor value types - confirm whether callers rely on it.
            return true; // currently only Tensor nodes are supported
        }
    }
}


/// <summary>
/// A class that queries and caches model metadata and exposes
/// it as properties
/// </summary>
public class ModelMetadata
{
    private string _producerName;
    private string _graphName;
    private string _domain;
    private string _description;
    private string _graphDescription;
    private long _version;
    private Dictionary _customMetadataMap = new Dictionary();

    internal ModelMetadata(InferenceSession session)
    {
        IntPtr modelMetadataHandle = IntPtr.Zero;

        var allocator = OrtAllocator.DefaultInstance;

        // Get the native ModelMetadata instance associated with the InferenceSession

        NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionGetModelMetadata(session.Handle, out modelMetadataHandle));

        try
        {

            // Process producer name
            IntPtr producerNameHandle = IntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetProducerName(modelMetadataHandle, allocator.Pointer, out producerNameHandle));
            using (var ortAllocation = new OrtMemoryAllocation(allocator, producerNameHandle, 0))
            {
                _producerName = NativeOnnxValueHelper.StringFromNativeUtf8(producerNameHandle);
            }

            // Process graph name
            IntPtr graphNameHandle = IntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphName(modelMetadataHandle, allocator.Pointer, out graphNameHandle));
            using (var ortAllocation = new OrtMemoryAllocation(allocator, graphNameHandle, 0))
            {
                _graphName = NativeOnnxValueHelper.StringFromNativeUtf8(graphNameHandle);
            }

            // Process domain
            IntPtr domainHandle = IntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDomain(modelMetadataHandle, allocator.Pointer, out domainHandle));
            using (var ortAllocation = new OrtMemoryAllocation(allocator, domainHandle, 0))
            {
                _domain = NativeOnnxValueHelper.StringFromNativeUtf8(domainHandle);
            }

            // Process description
IntPtr descriptionHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetDescription(modelMetadataHandle, allocator.Pointer, out descriptionHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, descriptionHandle, 0)) { _description = NativeOnnxValueHelper.StringFromNativeUtf8(descriptionHandle); } // Process graph description IntPtr graphDescriptionHandle = IntPtr.Zero; NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetGraphDescription(modelMetadataHandle, allocator.Pointer, out graphDescriptionHandle)); using (var ortAllocation = new OrtMemoryAllocation(allocator, graphDescriptionHandle, 0)) { _graphDescription = NativeOnnxValueHelper.StringFromNativeUtf8(graphDescriptionHandle); } // Process version NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetVersion(modelMetadataHandle, out _version)); // Process CustomMetadata Map IntPtr customMetadataMapKeysHandle = IntPtr.Zero; long numKeys; NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataGetCustomMetadataMapKeys(modelMetadataHandle, allocator.Pointer, out customMetadataMapKeysHandle, out numKeys)); // We have received an array of null terminated C strings which are the keys that we can use to lookup the custom metadata map // The OrtAllocator will finally free the customMetadataMapKeysHandle using (var ortAllocationKeysArray = new OrtMemoryAllocation(allocator, customMetadataMapKeysHandle, 0)) using (var ortAllocationKeys = new DisposableList((int)numKeys)) { // Put all the handles to each key in the DisposableList to be disposed off in an exception-safe manner for (int i = 0; i < (int)numKeys; ++i) { ortAllocationKeys.Add(new OrtMemoryAllocation(allocator, Marshal.ReadIntPtr(customMetadataMapKeysHandle, IntPtr.Size * i), 0)); } // Process each key via the stored key handles foreach (var allocation in ortAllocationKeys) { IntPtr keyHandle = allocation.Pointer; IntPtr valueHandle = IntPtr.Zero; 
NativeApiStatus.VerifySuccess(NativeMethods.OrtModelMetadataLookupCustomMetadataMap(modelMetadataHandle, allocator.Pointer, keyHandle, out valueHandle)); using (var ortAllocationValue = new OrtMemoryAllocation(allocator, valueHandle, 0)) { var key = NativeOnnxValueHelper.StringFromNativeUtf8(keyHandle); var value = NativeOnnxValueHelper.StringFromNativeUtf8(valueHandle); // Put the key/value pair into the dictionary _customMetadataMap[key] = value; } } } } finally { // Free ModelMetadata handle NativeMethods.OrtReleaseModelMetadata(modelMetadataHandle); } } /// /// Producer name string /// /// producer name string public string ProducerName { get { return _producerName; } } /// /// Graph name for this model /// /// graph name string public string GraphName { get { return _graphName; } } /// /// Domain for this model /// /// domain name string public string Domain { get { return _domain; } } /// /// Unstructured model description /// /// description string public string Description { get { return _description; } } /// /// Unstructured graph description /// /// description string public string GraphDescription { get { return _graphDescription; } } /// /// Version number /// /// long version integer public long Version { get { return _version; } } /// /// Custom metadata key/value pairs /// /// An instance of a Dictionary public Dictionary CustomMetadataMap { get { return _customMetadataMap; } } } }