// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Holds the options for configuring a TensorRT Execution Provider instance.
/// The native options object is owned by this SafeHandle and released in <see cref="ReleaseHandle"/>.
/// </summary>
public class OrtTensorRTProviderOptions : SafeHandle
{
    // Provider-option key whose value is also cached on the managed side for GetDeviceId().
    private const string DeviceIdKey = "device_id";

    // Last device id successfully passed through UpdateOptions; 0 until one is supplied.
    private int _deviceId = 0;

    internal IntPtr Handle
    {
        get
        {
            return handle;
        }
    }

    #region Constructor
    /// <summary>
    /// Constructs an empty OrtTensorRTProviderOptions instance.
    /// </summary>
    public OrtTensorRTProviderOptions() : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorRTProviderOptions(out handle));
    }
    #endregion

    #region Public Methods
    /// <summary>
    /// Get TensorRT EP provider options.
    /// </summary>
    /// <returns>C# UTF-16 encoded string produced from the native UTF-8 options string</returns>
    public string GetOptions()
    {
        var allocator = OrtAllocator.DefaultInstance;
        IntPtr providerOptions;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle, allocator.Pointer, out providerOptions));
        // The native buffer came from `allocator`; wrapping it in an OrtMemoryAllocation ensures it is
        // disposed after the managed copy is made (presumably returning it to that allocator).
        using (var ortAllocation = new OrtMemoryAllocation(allocator, providerOptions, 0))
        {
            return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions);
        }
    }

    /// <summary>
    /// Updates the configuration knobs of OrtTensorRTProviderOptions that will eventually be used to configure a TensorRT EP.
    /// Please refer to the following on different key/value pairs to configure a TensorRT EP and their meaning:
    /// https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html
    /// If a "device_id" entry is present, it is also parsed and cached so it can be read back via <see cref="GetDeviceId"/>.
    /// </summary>
    /// <param name="providerOptions">key/value pairs used to configure a TensorRT Execution Provider</param>
    /// <exception cref="ArgumentNullException"><paramref name="providerOptions"/> is null</exception>
    /// <exception cref="FormatException">the native update succeeded but "device_id" is not an integer</exception>
    public void UpdateOptions(Dictionary<string, string> providerOptions)
    {
        if (providerOptions == null)
        {
            throw new ArgumentNullException(nameof(providerOptions));
        }

        using (var cleanupList = new DisposableList<IDisposable>())
        {
            // Dictionary guarantees Values enumerate in the same order as Keys,
            // so index i of both arrays describes the same option.
            var keysArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Keys.ToArray(), n => n, cleanupList);
            var valuesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Values.ToArray(), n => n, cleanupList);
            NativeApiStatus.VerifySuccess(NativeMethods.OrtUpdateTensorRTProviderOptions(handle, keysArray, valuesArray, (UIntPtr)providerOptions.Count));

            // Single lookup; parse with the invariant culture because this is machine-readable data.
            string deviceIdValue;
            if (providerOptions.TryGetValue(DeviceIdKey, out deviceIdValue))
            {
                _deviceId = Int32.Parse(deviceIdValue,
                                        System.Globalization.NumberStyles.Integer,
                                        System.Globalization.CultureInfo.InvariantCulture);
            }
        }
    }

    /// <summary>
    /// Get device id of TensorRT EP.
    /// </summary>
    /// <returns>the device id last supplied through <see cref="UpdateOptions"/>, or 0 if none was supplied</returns>
    public int GetDeviceId()
    {
        return _deviceId;
    }
    #endregion

    #region Public Properties
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
    #endregion

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtTensorRTProviderOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseTensorRTProviderOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// Holds the options for configuring a CUDA Execution Provider instance.
/// The native options object is owned by this SafeHandle and released in <see cref="ReleaseHandle"/>.
/// </summary>
public class OrtCUDAProviderOptions : SafeHandle
{
    internal IntPtr Handle
    {
        get
        {
            return handle;
        }
    }

    #region Constructor
    /// <summary>
    /// Constructs an empty OrtCUDAProviderOptions instance.
    /// </summary>
    public OrtCUDAProviderOptions() : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCUDAProviderOptions(out handle));
    }
    #endregion

    #region Public Methods
    /// <summary>
    /// Get CUDA EP provider options.
    /// </summary>
    /// <returns>C# UTF-16 encoded string produced from the native UTF-8 options string</returns>
    public string GetOptions()
    {
        var allocator = OrtAllocator.DefaultInstance;
        IntPtr providerOptions;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle, allocator.Pointer, out providerOptions));
        // The native buffer came from `allocator`; wrapping it in an OrtMemoryAllocation ensures it is
        // disposed after the managed copy is made (presumably returning it to that allocator).
        using (var ortAllocation = new OrtMemoryAllocation(allocator, providerOptions, 0))
        {
            return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions);
        }
    }

    /// <summary>
    /// Updates the configuration knobs of OrtCUDAProviderOptions that will eventually be used to configure a CUDA EP.
    /// Please refer to the following on different key/value pairs to configure a CUDA EP and their meaning:
    /// https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
    /// </summary>
    /// <param name="providerOptions">key/value pairs used to configure a CUDA Execution Provider</param>
    /// <exception cref="ArgumentNullException"><paramref name="providerOptions"/> is null</exception>
    public void UpdateOptions(Dictionary<string, string> providerOptions)
    {
        if (providerOptions == null)
        {
            throw new ArgumentNullException(nameof(providerOptions));
        }

        using (var cleanupList = new DisposableList<IDisposable>())
        {
            // Dictionary guarantees Values enumerate in the same order as Keys,
            // so index i of both arrays describes the same option.
            var keysArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Keys.ToArray(), n => n, cleanupList);
            var valuesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Values.ToArray(), n => n, cleanupList);
            NativeApiStatus.VerifySuccess(NativeMethods.OrtUpdateCUDAProviderOptions(handle, keysArray, valuesArray, (UIntPtr)providerOptions.Count));
        }
    }
    #endregion

    #region Public Properties
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
    #endregion

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtCUDAProviderOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseCUDAProviderOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// This helper class contains methods to handle values of provider options.
/// </summary>
public class ProviderOptionsValueHelper
{
    /// <summary>
    /// Parses a string of the form "key1=value1;key2=value2" and adds each pair to <paramref name="dict"/>.
    /// Empty segments (for example from a trailing ';' or an empty input string) are ignored.
    /// </summary>
    /// <param name="s">';'-separated list of key=value pairs</param>
    /// <param name="dict">Dictionary instance to store the parsing result of s</param>
    /// <exception cref="ArgumentNullException"><paramref name="s"/> or <paramref name="dict"/> is null</exception>
    /// <exception cref="ArgumentException">
    /// A segment does not split into exactly one key and one value on '=', or a key is already present in <paramref name="dict"/>.
    /// Pairs parsed before the failing segment remain in <paramref name="dict"/>.
    /// </exception>
    public static void StringToDict(string s, Dictionary<string, string> dict)
    {
        if (s == null)
        {
            throw new ArgumentNullException(nameof(s));
        }
        if (dict == null)
        {
            throw new ArgumentNullException(nameof(dict));
        }

        // RemoveEmptyEntries lets "" and "k=v;" parse instead of failing on an empty segment.
        string[] pairs = s.Split(new[] { ';' }, StringSplitOptions.RemoveEmptyEntries);
        foreach (var pair in pairs)
        {
            string[] keyValue = pair.Split('=');
            if (keyValue.Length != 2)
            {
                throw new ArgumentException("Make sure input string contains key-value pairs, e.g. key1=value1;key2=value2...", nameof(s));
            }
            dict.Add(keyValue[0], keyValue[1]);
        }
    }
}
/// <summary>
/// CoreML flags for use with SessionOptions.
/// </summary>
/// <remarks>
/// Declared with [Flags]: every non-zero member occupies its own bit, so members can be OR-ed together.
/// </remarks>
[Flags]
public enum CoreMLFlags : uint
{
    COREML_FLAG_USE_NONE = 0,
    COREML_FLAG_USE_CPU_ONLY = 1 << 0,
    COREML_FLAG_ENABLE_ON_SUBGRAPH = 1 << 1,
    COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 1 << 2,
    // Alias for the highest flag currently defined.
    COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE,
}
/// <summary>
/// NNAPI flags for use with SessionOptions.
/// </summary>
/// <remarks>
/// Declared with [Flags]: every non-zero member occupies its own bit, so members can be OR-ed together.
/// </remarks>
[Flags]
public enum NnapiFlags : int
{
    NNAPI_FLAG_USE_NONE = 0,
    NNAPI_FLAG_USE_FP16 = 1 << 0,
    NNAPI_FLAG_USE_NCHW = 1 << 1,
    NNAPI_FLAG_CPU_DISABLED = 1 << 2,
    NNAPI_FLAG_CPU_ONLY = 1 << 3,
    // Alias for the highest flag currently defined.
    NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY
}
}