Commit a144865d authored by gaoqiong's avatar gaoqiong
Browse files

update v1.14.0

parent cf1acfd2
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Buffers;
using System.Diagnostics;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// A type of data that OrtValue encapsulates.
/// NOTE(review): numeric values appear to mirror the native ONNXType enum used by the
/// C API — keep them in sync; do not renumber. TODO confirm against onnxruntime_c_api.h.
/// </summary>
public enum OnnxValueType
{
ONNX_TYPE_UNKNOWN = 0, // Not set
ONNX_TYPE_TENSOR = 1, // It's a Tensor
ONNX_TYPE_SEQUENCE = 2, // It's an Onnx sequence which may be a sequence of Tensors/Maps/Sequences
ONNX_TYPE_MAP = 3, // It's a map
ONNX_TYPE_OPAQUE = 4, // It's an experimental Opaque object
ONNX_TYPE_SPARSETENSOR = 5, // It's a Sparse Tensor
}
/// <summary>
/// Represents a disposable OrtValue.
/// This class exposes a native instance of OrtValue.
/// The class implements IDisposable via SafeHandle and must
/// be disposed.
/// </summary>
public class OrtValue : SafeHandle
{
    /// <summary>
    /// Use factory methods to instantiate this class
    /// </summary>
    /// <param name="handle">Pointer to a native instance of OrtValue</param>
    /// <param name="owned">Default true, own the raw handle. Otherwise, the handle is owned by another instance
    /// However, we use this class to expose OrtValue that is owned by DisposableNamedOnnxValue
    /// </param>
    internal OrtValue(IntPtr handle, bool owned = true)
        : base(handle, true)
    {
        IsOwned = owned;
    }

    /// <summary>
    /// Native OrtValue handle for use in native API calls.
    /// </summary>
    internal IntPtr Handle { get { return handle; } }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }

    #region NamedOnnxValue/DisposableOnnxValue accommodations

    /// <summary>
    /// This internal interface is used to transfer ownership elsewhere.
    /// This instance must still be disposed in case there are other native
    /// objects still owned. This is a convenience method to ensure that an underlying
    /// OrtValue is disposed exactly once when exception is thrown.
    /// </summary>
    /// <returns>The raw native handle; after the call this instance no longer owns or references it</returns>
    internal IntPtr Disown()
    {
        var ret = Handle;
        handle = IntPtr.Zero;
        IsOwned = false;
        return ret;
    }

    // True when this instance is responsible for releasing the native OrtValue.
    internal bool IsOwned { get; private set; }
    #endregion

    /// <summary>
    /// Factory method to construct an OrtValue of Tensor type on top of pre-allocated memory.
    /// This can be a piece of native memory allocated by OrtAllocator (possibly on a device)
    /// or a piece of pinned managed memory.
    ///
    /// The resulting OrtValue does not own the underlying memory buffer and will not attempt to
    /// deallocate it.
    /// </summary>
    /// <param name="memInfo">Memory Info. For managed memory it is a default cpu.
    /// For Native memory must be obtained from the allocator or OrtMemoryAllocation instance</param>
    /// <param name="elementType">DataType for the Tensor</param>
    /// <param name="shape">Tensor shape</param>
    /// <param name="dataBuffer">Pointer to a raw memory buffer</param>
    /// <param name="bufferLength">Buffer length in bytes</param>
    /// <returns>A disposable instance of OrtValue</returns>
    /// <exception cref="OnnxRuntimeException">Thrown for string element type or when the buffer is too small for the shape</exception>
    public static OrtValue CreateTensorValueWithData(OrtMemoryInfo memInfo, TensorElementType elementType,
                                                     long[] shape,
                                                     IntPtr dataBuffer,
                                                     long bufferLength)
    {
        Type type;
        int width;
        if (!TensorElementTypeConverter.GetTypeAndWidth(elementType, out type, out width))
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Unable to query type information for data type: " + elementType.ToString());
        }

        // Strings are stored and owned by the native runtime; a raw buffer of managed
        // string references cannot be mapped directly.
        if (elementType == TensorElementType.String)
        {
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument,
                "Cannot map managed strings buffer to native OrtValue");
        }

        // Validate the caller-supplied buffer is large enough to back the requested shape.
        var shapeSize = ArrayUtilities.GetSizeForShape(shape);
        var requiredBufferSize = shapeSize * width;
        if (requiredBufferSize > bufferLength)
        {
            var message = String.Format("Shape of: {0} elements requires a buffer of at least {1} bytes. Provided: {2} bytes",
                shapeSize, requiredBufferSize, bufferLength);
            throw new OnnxRuntimeException(ErrorCode.InvalidArgument, message);
        }

        IntPtr ortValueHandle = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
                            memInfo.Pointer,
                            dataBuffer,
                            (UIntPtr)bufferLength,
                            shape,
                            (UIntPtr)shape.Length,
                            elementType,
                            out ortValueHandle
                        ));
        return new OrtValue(ortValueHandle);
    }

    /// <summary>
    /// This is a factory method creates a native Onnxruntime OrtValue containing a tensor.
    /// The method will attempt to pin managed memory so no copying occurs when data is passed down
    /// to native code.
    /// </summary>
    /// <param name="value">Tensor object</param>
    /// <param name="memoryHandle">For all tensor types but string tensors we endeavor to use managed memory
    /// to avoid additional allocation and copy. This out parameter represents a chunk of pinned memory which will need
    /// to be disposed when no longer needed. The lifespan of memoryHandle should eclipse the lifespan of the corresponding
    /// OrtValue.
    /// </param>
    /// <param name="elementType">discovered tensor element type</param>
    /// <returns>And instance of OrtValue constructed on top of the object</returns>
    /// <exception cref="NotSupportedException">Thrown when the value is not a tensor or has an unsupported element type</exception>
    public static OrtValue CreateFromTensorObject(Object value, out MemoryHandle? memoryHandle,
                                                  out TensorElementType elementType)
    {
        // Check if this is a Tensor
        if (!(value is TensorBase))
        {
            throw new NotSupportedException("The inference value " + nameof(value) + " is not of a supported type");
        }

        var tensorBase = value as TensorBase;
        var typeInfo = tensorBase.GetTypeInfo();
        if (typeInfo == null)
        {
            throw new OnnxRuntimeException(ErrorCode.RequirementNotRegistered, "BUG Check");
        }

        MemoryHandle? memHandle;
        OrtValue ortValue = null;
        int dataBufferLength = 0;
        long[] shape = null;
        int rank = 0;

        TensorElementType elType = typeInfo.ElementType;
        var typeSize = typeInfo.TypeSize;
        if (typeInfo.IsString)
        {
            // String tensors copy their data into native-owned storage; nothing to pin.
            ortValue = CreateStringTensor(value as Tensor<string>);
            memHandle = null;
        }
        else
        {
            // Pin the managed backing buffer of the matching concrete tensor type so the
            // native OrtValue can reference it without a copy.
            switch (elType)
            {
                case TensorElementType.Float:
                    PinAsTensor(value as Tensor<float>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Double:
                    PinAsTensor(value as Tensor<double>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Int32:
                    PinAsTensor(value as Tensor<int>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.UInt32:
                    PinAsTensor(value as Tensor<uint>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Int64:
                    PinAsTensor(value as Tensor<long>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.UInt64:
                    PinAsTensor(value as Tensor<ulong>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Int16:
                    PinAsTensor(value as Tensor<short>, typeSize, out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.UInt16:
                    PinAsTensor(value as Tensor<ushort>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.UInt8:
                    PinAsTensor(value as Tensor<byte>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Int8:
                    PinAsTensor(value as Tensor<sbyte>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Bool:
                    PinAsTensor(value as Tensor<bool>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.Float16:
                    PinAsTensor(value as Tensor<Float16>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                case TensorElementType.BFloat16:
                    PinAsTensor(value as Tensor<BFloat16>, typeSize,
                        out memHandle, out dataBufferLength,
                        out shape, out rank);
                    break;
                default:
                    throw new NotSupportedException("Element type: " + elType + " is not of a supported type");
            }

            try
            {
                Debug.Assert(memHandle.HasValue);
                IntPtr dataBufferPointer = IntPtr.Zero;
                unsafe
                {
                    dataBufferPointer = (IntPtr)memHandle.Value.Pointer;
                }

                IntPtr nativeValue;
                NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorWithDataAsOrtValue(
                    OrtMemoryInfo.DefaultInstance.Pointer,
                    dataBufferPointer,
                    (UIntPtr)(dataBufferLength),
                    shape,
                    (UIntPtr)rank,
                    elType,
                    out nativeValue));

                ortValue = new OrtValue(nativeValue);
            }
            catch (Exception)
            {
                // Release the pinned buffer if native tensor creation fails;
                // on success the caller takes responsibility for disposing it.
                memHandle?.Dispose();
                throw;
            }
        }
        memoryHandle = memHandle;
        elementType = elType;
        return ortValue;
    }

    /// <summary>
    /// Pins the dense backing buffer of a tensor and reports the buffer length (bytes),
    /// shape and rank needed to create a native tensor over it.
    /// A non-dense tensor is first converted to a DenseTensor (which copies).
    /// </summary>
    private static void PinAsTensor<T>(
                                Tensor<T> tensor,
                                int elementSize,
                                out MemoryHandle? pinnedHandle,
                                out int dataBufferLength,
                                out long[] shape,
                                out int rank)
    {
        if (tensor == null)
        {
            throw new OnnxRuntimeException(ErrorCode.Fail, "Cast to Tensor<T> failed. BUG check!");
        }

        if (tensor.IsReversedStride)
        {
            //TODO: not sure how to support reverse stride. may be able to calculate the shape differently
            throw new NotSupportedException(nameof(Tensor<T>) + " of reverseStride is not supported");
        }

        DenseTensor<T> dt = null;
        if (tensor is DenseTensor<T>)
        {
            dt = tensor as DenseTensor<T>;
        }
        else
        {
            dt = tensor.ToDenseTensor();
        }

        pinnedHandle = dt.Buffer.Pin();
        dataBufferLength = dt.Buffer.Length * elementSize;
        shape = new long[dt.Dimensions.Length];
        for (int i = 0; i < dt.Dimensions.Length; ++i)
        {
            shape[i] = dt.Dimensions[i];
        }
        rank = dt.Rank;
    }

    /// <summary>
    /// Creates a native string tensor and copies each element of the managed tensor
    /// into it as zero-terminated UTF-8. The native runtime owns the string storage.
    /// </summary>
    private static OrtValue CreateStringTensor(Tensor<string> tensor)
    {
        if (tensor == null)
        {
            throw new OnnxRuntimeException(ErrorCode.Fail, "Cast to Tensor<string> failed. BUG check!");
        }

        long[] shape = new long[tensor.Dimensions.Length];
        for (int i = 0; i < tensor.Dimensions.Length; i++)
        {
            shape[i] = tensor.Dimensions[i];
        }

        // allocate the native tensor
        IntPtr valueHandle = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorAsOrtValue(
                            OrtAllocator.DefaultInstance.Pointer,
                            shape,
                            (UIntPtr)(shape.Length),
                            TensorElementType.String,
                            out valueHandle
                            ));

        var ortValue = new OrtValue(valueHandle);
        try
        {
            // fill the native tensor, using GetValue(index) from the Tensor<string>
            var len = tensor.Length;
            var nativeStrings = new IntPtr[len];
            using (var pinnedHandles = new DisposableList<PinnedGCHandle>((int)len))
            {
                for (int i = 0; i < len; i++)
                {
                    var utf8str = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(tensor.GetValue(i));
                    var gcHandle = GCHandle.Alloc(utf8str, GCHandleType.Pinned);
                    nativeStrings[i] = gcHandle.AddrOfPinnedObject();
                    pinnedHandles.Add(new PinnedGCHandle(gcHandle));
                }

                // Keep the pointer array itself pinned while the native call runs.
                using (var pinnedStrings = new PinnedGCHandle(GCHandle.Alloc(nativeStrings, GCHandleType.Pinned)))
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtFillStringTensor(ortValue.Handle, nativeStrings, (UIntPtr)len));
            }
        }
        catch (Exception)
        {
            // Catch all exceptions (not just OnnxRuntimeException) so the native
            // OrtValue cannot leak if UTF-8 conversion or pinning throws.
            ortValue.Dispose();
            throw;
        }
        return ortValue;
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtValue
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        // We have to surrender ownership to some legacy classes
        // Or we never had that ownership to begin with
        if (IsOwned)
        {
            NativeMethods.OrtReleaseValue(handle);
        }
        // Prevent use after disposal
        handle = IntPtr.Zero;
        return true;
    }
    // No need for the finalizer
    #endregion
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Holds pre-packed weights of shared initializers so they can be shared across
/// sessions that use those initializers, saving memory by reusing the same
/// pre-packed versions instead of creating one per session.
/// </summary>
public class PrePackedWeightsContainer : SafeHandle
{
    /// <summary>
    /// Constructs an empty PrePackedWeightsContainer
    /// </summary>
    public PrePackedWeightsContainer()
        : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreatePrepackedWeightsContainer(out handle));
    }

    /// <summary>
    /// Internal accessor to call native methods
    /// </summary>
    internal IntPtr Pointer => handle;

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid => handle == IntPtr.Zero;

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to release the native
    /// pre-packed weights container.
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleasePrepackedWeightsContainer(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Holds the options for configuring a TensorRT Execution Provider instance
/// </summary>
public class OrtTensorRTProviderOptions : SafeHandle
{
    /// <summary>
    /// Native handle, for use in native API calls.
    /// </summary>
    internal IntPtr Handle
    {
        get
        {
            return handle;
        }
    }

    // Cached device id; refreshed by UpdateOptions() when a "device_id" entry is supplied.
    private int _deviceId = 0;
    // Option key whose value selects the GPU device for the TensorRT EP.
    private const string DeviceIdKey = "device_id";

    #region Constructor
    /// <summary>
    /// Constructs an empty OrtTensorRTProviderOptions instance
    /// </summary>
    public OrtTensorRTProviderOptions() : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateTensorRTProviderOptions(out handle));
    }
    #endregion

    #region Public Methods
    /// <summary>
    /// Get TensorRT EP provider options
    /// </summary>
    /// <returns> return C# UTF-16 encoded string </returns>
    public string GetOptions()
    {
        var allocator = OrtAllocator.DefaultInstance;

        // The native call returns an allocator-owned UTF-8 string; wrap it so it is
        // freed by the same allocator once converted to a managed string.
        IntPtr providerOptions = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetTensorRTProviderOptionsAsString(handle, allocator.Pointer, out providerOptions));
        using (var ortAllocation = new OrtMemoryAllocation(allocator, providerOptions, 0))
        {
            return NativeOnnxValueHelper.StringFromNativeUtf8(providerOptions);
        }
    }

    /// <summary>
    /// Updates the configuration knobs of OrtTensorRTProviderOptions that will eventually be used to configure a TensorRT EP
    /// Please refer to the following on different key/value pairs to configure a TensorRT EP and their meaning:
    /// https://www.onnxruntime.ai/docs/reference/execution-providers/TensorRT-ExecutionProvider.html
    /// </summary>
    /// <param name="providerOptions">key/value pairs used to configure a TensorRT Execution Provider</param>
    public void UpdateOptions(Dictionary<string, string> providerOptions)
    {
        using (var cleanupList = new DisposableList<IDisposable>())
        {
            var keysArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Keys.ToArray(), n => n, cleanupList);
            var valuesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Values.ToArray(), n => n, cleanupList);

            NativeApiStatus.VerifySuccess(NativeMethods.OrtUpdateTensorRTProviderOptions(handle, keysArray, valuesArray, (UIntPtr)providerOptions.Count));

            // TryGetValue avoids the ContainsKey + indexer double lookup; parse with the
            // invariant culture so the result does not depend on the user's locale (CA1305).
            string deviceIdValue;
            if (providerOptions.TryGetValue(DeviceIdKey, out deviceIdValue))
            {
                _deviceId = Int32.Parse(deviceIdValue, System.Globalization.CultureInfo.InvariantCulture);
            }
        }
    }

    /// <summary>
    /// Get device id of TensorRT EP.
    /// </summary>
    /// <returns> device id </returns>
    public int GetDeviceId()
    {
        return _deviceId;
    }
    #endregion

    #region Public Properties
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
    #endregion

    #region Private Methods
    #endregion

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtTensorRTProviderOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseTensorRTProviderOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// Holds the options for configuring a CUDA Execution Provider instance
/// </summary>
public class OrtCUDAProviderOptions : SafeHandle
{
    /// <summary>
    /// Native handle, for use in native API calls.
    /// </summary>
    internal IntPtr Handle => handle;

    #region Constructor
    /// <summary>
    /// Constructs an empty OrtCUDAProviderOptions instance
    /// </summary>
    public OrtCUDAProviderOptions() : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateCUDAProviderOptions(out handle));
    }
    #endregion

    #region Public Methods
    /// <summary>
    /// Get CUDA EP provider options
    /// </summary>
    /// <returns> return C# UTF-16 encoded string </returns>
    public string GetOptions()
    {
        var allocator = OrtAllocator.DefaultInstance;

        // The native call hands back an allocator-owned UTF-8 string; wrap it so the
        // same allocator frees it after conversion to a managed string.
        IntPtr optionsStr = IntPtr.Zero;
        NativeApiStatus.VerifySuccess(NativeMethods.OrtGetCUDAProviderOptionsAsString(handle, allocator.Pointer, out optionsStr));
        using (var allocation = new OrtMemoryAllocation(allocator, optionsStr, 0))
        {
            return NativeOnnxValueHelper.StringFromNativeUtf8(optionsStr);
        }
    }

    /// <summary>
    /// Updates the configuration knobs of OrtCUDAProviderOptions that will eventually be used to configure a CUDA EP
    /// Please refer to the following on different key/value pairs to configure a CUDA EP and their meaning:
    /// https://www.onnxruntime.ai/docs/reference/execution-providers/CUDA-ExecutionProvider.html
    /// </summary>
    /// <param name="providerOptions">key/value pairs used to configure a CUDA Execution Provider</param>
    public void UpdateOptions(Dictionary<string, string> providerOptions)
    {
        using (var toDispose = new DisposableList<IDisposable>())
        {
            var utf8Keys = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Keys.ToArray(), n => n, toDispose);
            var utf8Values = NativeOnnxValueHelper.ConvertNamesToUtf8(providerOptions.Values.ToArray(), n => n, toDispose);
            NativeApiStatus.VerifySuccess(NativeMethods.OrtUpdateCUDAProviderOptions(handle, utf8Keys, utf8Values, (UIntPtr)providerOptions.Count));
        }
    }
    #endregion

    #region Public Properties
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid => handle == IntPtr.Zero;
    #endregion

    #region Private Methods
    #endregion

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of OrtCUDAProviderOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseCUDAProviderOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
/// <summary>
/// This helper class contains methods to handle values of provider options
/// </summary>
public class ProviderOptionsValueHelper
{
    /// <summary>
    /// Parses a semicolon-separated list of key=value pairs and saves each pair
    /// into the supplied dictionary.
    /// </summary>
    /// <param name="s">C# string, e.g. "key1=value1;key2=value2"</param>
    /// <param name="dict">Dictionary instance to store the parsing result of s</param>
    /// <exception cref="ArgumentException">Thrown when an entry is not of the form key=value</exception>
    public static void StringToDict(string s, Dictionary<string, string> dict)
    {
        // Fixed typo: "paris" -> "pairs" (variable name and error message).
        string[] pairs = s.Split(';');
        foreach (var p in pairs)
        {
            string[] keyValue = p.Split('=');
            if (keyValue.Length != 2)
            {
                throw new ArgumentException("Make sure input string contains key-value pairs, e.g. key1=value1;key2=value2...", nameof(s));
            }
            dict.Add(keyValue[0], keyValue[1]);
        }
    }
}
/// <summary>
/// CoreML flags for use with SessionOptions.
/// NOTE(review): values must match the COREML_FLAG_* constants in the native header
/// linked below — keep in sync when updating.
/// </summary>
/// <see cref="https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/coreml/coreml_provider_factory.h"/>
[Flags]
public enum CoreMLFlags : uint
{
COREML_FLAG_USE_NONE = 0x000, // No flags set
COREML_FLAG_USE_CPU_ONLY = 0x001, // Restrict CoreML to CPU execution
COREML_FLAG_ENABLE_ON_SUBGRAPH = 0x002, // Allow CoreML on subgraphs
COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE = 0x004, // Only enable on devices with an Apple Neural Engine
COREML_FLAG_LAST = COREML_FLAG_ONLY_ENABLE_DEVICE_WITH_ANE, // Sentinel: highest defined flag
}
/// <summary>
/// NNAPI flags for use with SessionOptions.
/// NOTE(review): values must match the NNAPI_FLAG_* constants in the native header
/// linked below — keep in sync when updating.
/// </summary>
/// <see cref="https://github.com/microsoft/onnxruntime/blob/main/include/onnxruntime/core/providers/nnapi/nnapi_provider_factory.h"/>
[Flags]
public enum NnapiFlags
{
NNAPI_FLAG_USE_NONE = 0x000, // No flags set
NNAPI_FLAG_USE_FP16 = 0x001, // Allow fp16 precision
NNAPI_FLAG_USE_NCHW = 0x002, // Use NCHW layout
NNAPI_FLAG_CPU_DISABLED = 0x004, // Disallow the NNAPI CPU fallback device
NNAPI_FLAG_CPU_ONLY = 0x008, // Restrict NNAPI to its CPU device
NNAPI_FLAG_LAST = NNAPI_FLAG_CPU_ONLY // Sentinel: highest defined flag
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Sets various runtime options.
/// </summary>
public class RunOptions : SafeHandle
{
    // Backing stores mirror the values pushed to the native RunOptions instance.
    private OrtLoggingLevel _severityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
    private int _verbosityLevel = 0;
    private string _runTag = "";
    private bool _terminationRequested = false; // default matches the native RunOptions

    /// <summary>
    /// Native handle, for use in native API calls.
    /// </summary>
    internal IntPtr Handle
    {
        get { return handle; }
    }

    /// <summary>
    /// Default __ctor. Creates default RuntimeOptions
    /// </summary>
    public RunOptions()
        : base(IntPtr.Zero, true)
    {
        NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateRunOptions(out handle));
    }

    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid
    {
        get { return handle == IntPtr.Zero; }
    }

    /// <summary>
    /// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING
    /// </summary>
    public OrtLoggingLevel LogSeverityLevel
    {
        get { return _severityLevel; }
        set
        {
            // Push to the native side first; cache only on success.
            NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetRunLogSeverityLevel(handle, value));
            _severityLevel = value;
        }
    }

    /// <summary>
    /// Log Verbosity Level for the session logs. Default = 0. Valid values are >=0.
    /// This takes into effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
    /// </summary>
    public int LogVerbosityLevel
    {
        get { return _verbosityLevel; }
        set
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetRunLogVerbosityLevel(handle, value));
            _verbosityLevel = value;
        }
    }

    /// <summary>
    /// Log tag to be used during the run. default = ""
    /// </summary>
    public string LogId
    {
        get { return _runTag; }
        set
        {
            // The native API expects a zero-terminated UTF-8 string, pinned for the call.
            var utf8Tag = NativeOnnxValueHelper.StringToZeroTerminatedUtf8(value);
            using (var pinnedTag = new PinnedGCHandle(GCHandle.Alloc(utf8Tag, GCHandleType.Pinned)))
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetRunTag(handle, pinnedTag.Pointer));
            }
            _runTag = value;
        }
    }

    /// <summary>
    /// Sets a flag to terminate all Run() calls that are currently using this RunOptions object
    /// Default = false
    /// </summary>
    /// <value>terminate flag value</value>
    public bool Terminate
    {
        get { return _terminationRequested; }
        set
        {
            // Only call into the native runtime on an actual state transition.
            if (value == _terminationRequested)
            {
                return;
            }
            if (value)
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsSetTerminate(handle));
            }
            else
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtRunOptionsUnsetTerminate(handle));
            }
            _terminationRequested = value;
        }
    }

    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of RunOptions
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeMethods.OrtReleaseRunOptions(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
}
\ No newline at end of file
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Graph optimization level to use with SessionOptions
/// [https://github.com/microsoft/onnxruntime/blob/main/docs/ONNX_Runtime_Graph_Optimizations.md]
/// NOTE(review): values appear to mirror the native GraphOptimizationLevel enum — keep in sync.
/// </summary>
public enum GraphOptimizationLevel
{
ORT_DISABLE_ALL = 0, // Disable all graph optimizations
ORT_ENABLE_BASIC = 1, // Enable basic optimizations only
ORT_ENABLE_EXTENDED = 2, // Enable basic + extended optimizations
ORT_ENABLE_ALL = 99 // Enable all available optimizations
}
/// <summary>
/// Controls whether you want to execute operators in the graph sequentially or in parallel.
/// Usually when the model has many branches, setting this option to ExecutionMode.ORT_PARALLEL
/// will give you better performance.
/// See [ONNX_Runtime_Perf_Tuning.md] for more details.
/// </summary>
public enum ExecutionMode
{
ORT_SEQUENTIAL = 0, // Execute operators one at a time, in topological order
ORT_PARALLEL = 1, // Allow independent operators to execute in parallel
}
/// <summary>
/// Holds the options for creating an InferenceSession
/// </summary>
public class SessionOptions : SafeHandle
{
// Delay-loaded CUDA or cuDNN DLLs. Currently, delayload is disabled. See cmake/CMakeLists.txt for more information.
private static string[] cudaDelayLoadedLibs = { };
private static string[] trtDelayLoadedLibs = { };
#region Constructor and Factory methods
/// <summary>
/// Constructs an empty SessionOptions.
/// The instance wraps a native OrtSessionOptions and must be disposed.
/// </summary>
public SessionOptions()
: base(IntPtr.Zero, true)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateSessionOptions(out handle));
}
/// <summary>
/// A helper method to construct a SessionOptions object for CUDA execution.
/// Use only if CUDA is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">CUDA device id to execute on</param>
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithCudaProvider(int deviceId = 0)
{
    CheckCudaExecutionProviderDLLs();
    SessionOptions options = new SessionOptions();
    try
    {
        options.AppendExecutionProvider_CUDA(deviceId);
        return options;
    }
    catch (Exception)
    {
        // Match the other factory overloads: dispose on failure so the
        // native SessionOptions does not leak if appending the EP throws.
        options.Dispose();
        throw;
    }
}
/// <summary>
/// A helper method to construct a SessionOptions object for CUDA execution provider.
/// Use only if CUDA is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="cudaProviderOptions">CUDA EP provider options</param>
/// <returns>A SessionsOptions() object configured for execution on provider options</returns>
public static SessionOptions MakeSessionOptionWithCudaProvider(OrtCUDAProviderOptions cudaProviderOptions)
{
    CheckCudaExecutionProviderDLLs();
    var sessionOptions = new SessionOptions();
    try
    {
        sessionOptions.AppendExecutionProvider_CUDA(cudaProviderOptions);
        return sessionOptions;
    }
    catch (Exception)
    {
        // Dispose on failure so the native SessionOptions does not leak.
        sessionOptions.Dispose();
        throw;
    }
}
/// <summary>
/// A helper method to construct a SessionOptions object for TensorRT execution.
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">device id to execute on</param>
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithTensorrtProvider(int deviceId = 0)
{
    CheckTensorrtExecutionProviderDLLs();
    var sessionOptions = new SessionOptions();
    try
    {
        // Register TensorRT first, then CUDA as its fallback on the same device.
        sessionOptions.AppendExecutionProvider_Tensorrt(deviceId);
        sessionOptions.AppendExecutionProvider_CUDA(deviceId);
        return sessionOptions;
    }
    catch (Exception)
    {
        // Dispose on failure so the native SessionOptions does not leak.
        sessionOptions.Dispose();
        throw;
    }
}
/// <summary>
/// A helper method to construct a SessionOptions object for TensorRT execution provider.
/// Use only if CUDA/TensorRT are installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="trtProviderOptions">TensorRT EP provider options</param>
/// <returns>A SessionsOptions() object configured for execution on provider options</returns>
public static SessionOptions MakeSessionOptionWithTensorrtProvider(OrtTensorRTProviderOptions trtProviderOptions)
{
    CheckTensorrtExecutionProviderDLLs();
    var sessionOptions = new SessionOptions();
    try
    {
        // Make sure that CUDA EP uses the same device id as TensorRT EP.
        sessionOptions.AppendExecutionProvider_Tensorrt(trtProviderOptions);
        sessionOptions.AppendExecutionProvider_CUDA(trtProviderOptions.GetDeviceId());
        return sessionOptions;
    }
    catch (Exception)
    {
        // Dispose on failure so the native SessionOptions does not leak.
        sessionOptions.Dispose();
        throw;
    }
}
/// <summary>
/// A helper method to construct a SessionOptions object for TVM execution.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="settings">settings string, comprises of comma separated key:value pairs. default is empty</param>
/// <returns>A SessionsOptions() object configured for execution with TVM</returns>
public static SessionOptions MakeSessionOptionWithTvmProvider(String settings = "")
{
    SessionOptions options = new SessionOptions();
    try
    {
        options.AppendExecutionProvider_Tvm(settings);
        return options;
    }
    catch (Exception)
    {
        // Consistent with the CUDA/TensorRT factories: dispose on failure
        // so the native SessionOptions does not leak.
        options.Dispose();
        throw;
    }
}
/// <summary>
/// A helper method to construct a SessionOptions object for ROCM execution.
/// Use only if ROCM is installed and you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">Device Id</param>
/// <returns>A SessionsOptions() object configured for execution on deviceId</returns>
public static SessionOptions MakeSessionOptionWithRocmProvider(int deviceId = 0)
{
    SessionOptions options = new SessionOptions();
    try
    {
        options.AppendExecutionProvider_ROCM(deviceId);
        return options;
    }
    catch (Exception)
    {
        // Consistent with the CUDA/TensorRT factories: dispose on failure
        // so the native SessionOptions does not leak.
        options.Dispose();
        throw;
    }
}
#endregion
#region ExecutionProviderAppends
/// <summary>
/// Appends CPU EP to a list of available execution providers for the session.
/// Providers are tried in the order they are appended.
/// </summary>
/// <param name="useArena">1 - use arena, 0 - do not use arena</param>
public void AppendExecutionProvider_CPU(int useArena = 1)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CPU(handle, useArena));
}
/// <summary>
/// Appends the oneDNN (DNNL) EP to the session's provider list.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="useArena">1 - use allocation arena, 0 - otherwise</param>
public void AppendExecutionProvider_Dnnl(int useArena = 1)
{
#if __MOBILE__
throw new NotSupportedException("The DNNL Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Dnnl(handle, useArena));
#endif
}
/// <summary>
/// Appends the CUDA EP (default configuration) to the session's provider list.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="deviceId">integer device ID</param>
public void AppendExecutionProvider_CUDA(int deviceId = 0)
{
#if __MOBILE__
throw new NotSupportedException("The CUDA Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_CUDA(handle, deviceId));
#endif
}
/// <summary>
/// Append a CUDA EP instance (based on specified configuration) to the SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="cudaProviderOptions">CUDA EP provider options</param>
public void AppendExecutionProvider_CUDA(OrtCUDAProviderOptions cudaProviderOptions)
{
#if __MOBILE__
throw new NotSupportedException("The CUDA Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_CUDA_V2(handle, cudaProviderOptions.Handle));
#endif
}
/// <summary>
/// Appends the DirectML EP to the session's provider list.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="deviceId">device identification</param>
public void AppendExecutionProvider_DML(int deviceId = 0)
{
#if __MOBILE__
throw new NotSupportedException("The DML Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_DML(handle, deviceId));
#endif
}
/// <summary>
/// Appends the OpenVINO EP to the session's provider list.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="deviceId">device identification, default empty string</param>
public void AppendExecutionProvider_OpenVINO(string deviceId = "")
{
#if __MOBILE__
throw new NotSupportedException("The OpenVINO Execution Provider is not supported in this build");
#else
// The native API expects a zero-terminated UTF-8 string, pinned for the call.
var deviceIdPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(deviceId), GCHandleType.Pinned);
using (var pinnedDeviceIdName = new PinnedGCHandle(deviceIdPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_OpenVINO(handle, pinnedDeviceIdName.Pointer));
}
#endif
}
/// <summary>
/// Appends the TensorRT EP (default configuration) to the session's provider list.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// Not available in mobile builds.
/// </summary>
/// <param name="deviceId">device identification</param>
public void AppendExecutionProvider_Tensorrt(int deviceId = 0)
{
#if __MOBILE__
throw new NotSupportedException("The TensorRT Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tensorrt(handle, deviceId));
#endif
}
/// <summary>
/// Append a TensorRT EP instance (based on specified configuration) to the SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="trtProviderOptions">TensorRT EP provider options</param>
/// <exception cref="NotSupportedException">Always thrown in mobile builds, which do not include TensorRT.</exception>
public void AppendExecutionProvider_Tensorrt(OrtTensorRTProviderOptions trtProviderOptions)
{
#if __MOBILE__
throw new NotSupportedException("The TensorRT Execution Provider is not supported in this build");
#else
// The V2 native entry point takes the full provider-options handle rather than just a device id.
NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider_TensorRT_V2(handle, trtProviderOptions.Handle));
#endif
}
/// <summary>
/// Append the ROCm execution provider to this SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">Device Id</param>
/// <exception cref="NotSupportedException">Always thrown in mobile builds, which do not include ROCm.</exception>
public void AppendExecutionProvider_ROCM(int deviceId = 0)
{
#if __MOBILE__
throw new NotSupportedException("The ROCM Execution Provider is not supported in this build");
#else
NativeApiStatus.VerifySuccess(
NativeMethods.OrtSessionOptionsAppendExecutionProvider_ROCM(handle, deviceId));
#endif
}
/// <summary>
/// Append the MIGraphX execution provider to this SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="deviceId">device identification</param>
/// <exception cref="NotSupportedException">Always thrown in mobile builds, which do not include MIGraphX.</exception>
public void AppendExecutionProvider_MIGraphX(int deviceId = 0)
{
#if __MOBILE__
    // Fix: dropped the stray '$' interpolation prefix - the message has no interpolation
    // holes, and plain literals match every sibling AppendExecutionProvider_* method.
    throw new NotSupportedException("The MIGraphX Execution Provider is not supported in this build");
#else
    NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_MIGraphX(handle, deviceId));
#endif
}
/// <summary>
/// Append the NNAPI execution provider to this SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="nnapiFlags">NNAPI specific flag mask</param>
/// <exception cref="NotSupportedException">Always thrown outside Android builds.</exception>
public void AppendExecutionProvider_Nnapi(NnapiFlags nnapiFlags = NnapiFlags.NNAPI_FLAG_USE_NONE)
{
#if __ANDROID__
// NNAPI is only compiled into Android builds; the flag mask maps directly onto the native bitmask.
NativeApiStatus.VerifySuccess(
NativeMethods.OrtSessionOptionsAppendExecutionProvider_Nnapi(handle, (uint)nnapiFlags));
#else
throw new NotSupportedException("The NNAPI Execution Provider is not supported in this build");
#endif
}
/// <summary>
/// Append the CoreML execution provider to this SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="coremlFlags">CoreML specific flags</param>
/// <exception cref="NotSupportedException">Thrown when CoreML is not compiled into this build,
/// or at runtime on a non-OSX platform in __ENABLE_COREML__ builds.</exception>
public void AppendExecutionProvider_CoreML(CoreMLFlags coremlFlags = CoreMLFlags.COREML_FLAG_USE_NONE)
{
#if __IOS__
// iOS builds always include CoreML support.
NativeApiStatus.VerifySuccess(
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CoreML(handle, (uint)coremlFlags));
#else
#if __ENABLE_COREML__
// only attempt if this is OSX
if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
NativeApiStatus.VerifySuccess(
NativeMethods.OrtSessionOptionsAppendExecutionProvider_CoreML(handle, (uint)coremlFlags));
}
else
#endif
// Note: when __ENABLE_COREML__ is undefined, the braces below form a plain block
// (not an else-branch) so the throw is unconditional.
{
throw new NotSupportedException("The CoreML Execution Provider is not supported in this build");
}
#endif
}
/// <summary>
/// Append the TVM execution provider to this SessionOptions instance.
/// Use only if you have the onnxruntime package specific to this Execution Provider.
/// </summary>
/// <param name="settings">string with TVM specific settings</param>
/// <exception cref="NotSupportedException">Always thrown in mobile builds, which do not include TVM.</exception>
public void AppendExecutionProvider_Tvm(string settings = "")
{
#if __MOBILE__
throw new NotSupportedException("The TVM Execution Provider is not supported in this build");
#else
// Settings are passed to native code as a pinned, NUL-terminated UTF-8 string.
var settingsPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(settings), GCHandleType.Pinned);
using (var pinnedSettingsName = new PinnedGCHandle(settingsPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSessionOptionsAppendExecutionProvider_Tvm(handle, pinnedSettingsName.Pointer));
}
#endif
}
/// <summary>
/// Append the SNPE or XNNPACK execution provider to this SessionOptions instance.
/// </summary>
/// <param name="providerName">Execution provider to add. 'SNPE' or 'XNNPACK' are currently supported.</param>
/// <param name="providerOptions">Optional key/value pairs to specify execution provider options.</param>
/// <exception cref="NotSupportedException">Thrown for any provider name other than SNPE or XNNPACK.</exception>
public void AppendExecutionProvider(string providerName, Dictionary<string, string> providerOptions = null)
{
    // Guard first: this entry point deliberately supports only these two providers.
    bool supported = providerName == "SNPE" || providerName == "XNNPACK";
    if (!supported)
    {
        throw new NotSupportedException(
            "Only SNPE and XNNPACK execution providers can be enabled by this method.");
    }

    // Normalize to an empty option set so the marshalling below is uniform.
    if (providerOptions == null)
    {
        providerOptions = new Dictionary<string, string>();
    }

    using (var cleanupList = new DisposableList<IDisposable>())
    {
        // Wrap the single provider name in an array so ConvertNamesToUtf8 handles all three conversions.
        var epArray = NativeOnnxValueHelper.ConvertNamesToUtf8(
            new string[] { providerName }, n => n, cleanupList);
        var keysArray = NativeOnnxValueHelper.ConvertNamesToUtf8(
            providerOptions.Keys.ToArray(), n => n, cleanupList);
        var valuesArray = NativeOnnxValueHelper.ConvertNamesToUtf8(
            providerOptions.Values.ToArray(), n => n, cleanupList);

        NativeApiStatus.VerifySuccess(NativeMethods.SessionOptionsAppendExecutionProvider(
            handle, epArray[0], keysArray, valuesArray, (UIntPtr)providerOptions.Count));
    }
}
#endregion //ExecutionProviderAppends
#region Public Methods
/// <summary>
/// (Deprecated) Loads a DLL named 'libraryPath' and looks for this entry point:
/// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
/// It then passes in the provided session options to this function along with the api base.
/// Deprecated in favor of RegisterCustomOpLibraryV2() because it provides users with the library handle
/// to release when all sessions relying on it are destroyed
/// </summary>
/// <param name="libraryPath">path to the custom op library</param>
[ObsoleteAttribute("RegisterCustomOpLibrary(...) is obsolete. Use RegisterCustomOpLibraryV2(...) instead.", false)]
public void RegisterCustomOpLibrary(string libraryPath)
{
// The native library handle is intentionally discarded here - the caller has no way to
// unload the library, which is exactly why this overload is obsolete (see V2 below).
IntPtr libraryHandle = IntPtr.Zero;
var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
}
}
/// <summary>
/// Loads a DLL named 'libraryPath' and looks for this entry point:
/// OrtStatus* RegisterCustomOps(OrtSessionOptions* options, const OrtApiBase* api);
/// It then passes in the provided session options to this function along with the api base.
/// The handle to the loaded library is returned in 'libraryHandle'.
/// It can be unloaded by the caller after all sessions using the passed in
/// session options are destroyed, or if an error occurs and it is non null.
/// Hint: .NET Core 3.1 has a 'NativeLibrary' class that can be used to free the library handle
/// </summary>
/// <param name="libraryPath">Custom op library path</param>
/// <param name="libraryHandle">out parameter, library handle</param>
public void RegisterCustomOpLibraryV2(string libraryPath, out IntPtr libraryHandle)
{
// Path is marshalled as a pinned, NUL-terminated UTF-8 string for the native call.
var libraryPathPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(libraryPath), GCHandleType.Pinned);
using (var pinnedlibraryPath = new PinnedGCHandle(libraryPathPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtRegisterCustomOpsLibrary(handle, pinnedlibraryPath.Pointer, out libraryHandle));
}
}
/// <summary>
/// Add a pre-allocated initializer to a session. If a model contains an initializer with a name
/// that is same as the name passed to this API call, ORT will use this initializer instance
/// instead of deserializing one from the model file. This is useful when you want to share
/// the same initializer across sessions.
/// </summary>
/// <param name="name">name of the initializer</param>
/// <param name="ortValue">OrtValue containing the initializer. Lifetime of 'val' and the underlying initializer buffer must be
/// managed by the user (created using the CreateTensorWithDataAsOrtValue API) and it must outlive the session object</param>
public void AddInitializer(string name, OrtValue ortValue)
{
// Name is marshalled as a pinned, NUL-terminated UTF-8 string for the native call.
var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddInitializer(handle, pinnedName.Pointer, ortValue.Handle));
}
}
/// <summary>
/// Set a single session configuration entry as a pair of strings
/// If a configuration with same key exists, this will overwrite the configuration with the given configValue
/// </summary>
/// <param name="configKey">config key name</param>
/// <param name="configValue">config key value</param>
public void AddSessionConfigEntry(string configKey, string configValue)
{
// Both key and value are marshalled as pinned, NUL-terminated UTF-8 strings for the native call.
using (var pinnedConfigKeyName = new PinnedGCHandle(GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configKey), GCHandleType.Pinned)))
using (var pinnedConfigValueName = new PinnedGCHandle(GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(configValue), GCHandleType.Pinned)))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddSessionConfigEntry(handle,
pinnedConfigKeyName.Pointer, pinnedConfigValueName.Pointer));
}
}
/// <summary>
/// Override symbolic dimensions (by specific denotation strings) with actual values if known at session initialization time to enable
/// optimizations that can take advantage of fixed values (such as memory planning, etc)
/// </summary>
/// <param name="dimDenotation">denotation name</param>
/// <param name="dimValue">denotation value</param>
public void AddFreeDimensionOverride(string dimDenotation, long dimValue)
{
// Denotation is marshalled as a pinned, NUL-terminated UTF-8 string for the native call.
var utf8DimDenotationPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimDenotation), GCHandleType.Pinned);
using (var pinnedDimDenotation = new PinnedGCHandle(utf8DimDenotationPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverride(handle, pinnedDimDenotation.Pointer, dimValue));
}
}
/// <summary>
/// Override symbolic dimensions (by specific name strings) with actual values if known at session initialization time to enable
/// optimizations that can take advantage of fixed values (such as memory planning, etc)
/// </summary>
/// <param name="dimName">dimension name</param>
/// <param name="dimValue">dimension value</param>
public void AddFreeDimensionOverrideByName(string dimName, long dimValue)
{
// Dimension name is marshalled as a pinned, NUL-terminated UTF-8 string for the native call.
var utf8DimNamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(dimName), GCHandleType.Pinned);
using (var pinnedDimName = new PinnedGCHandle(utf8DimNamePinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtAddFreeDimensionOverrideByName(handle, pinnedDimName.Pointer, dimValue));
}
}
#endregion
// Exposes the raw native handle wrapped by this SafeHandle to other
// classes in this assembly that need to pass it to native calls.
internal IntPtr Handle
{
get
{
return handle;
}
}
#region Public Properties
/// <summary>
/// Overrides SafeHandle.IsInvalid
/// </summary>
/// <value>returns true if handle is equal to Zero</value>
// IntPtr.Zero is the only sentinel; there is no other "invalid" value for this handle.
public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
/// <summary>
/// Enables the use of the memory allocation patterns in the first Run() call for subsequent runs. Default = true.
/// </summary>
/// <value>returns enableMemoryPattern flag value</value>
public bool EnableMemoryPattern
{
get
{
return _enableMemoryPattern;
}
set
{
// Only cross into native code on an actual state transition;
// assigning the current value is a no-op.
if (!_enableMemoryPattern && value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableMemPattern(handle));
_enableMemoryPattern = true;
}
else if (_enableMemoryPattern && !value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableMemPattern(handle));
_enableMemoryPattern = false;
}
}
}
// Shadow of the native-side flag; must stay in sync with the native session options.
private bool _enableMemoryPattern = true;
/// <summary>
/// Path prefix to use for output of profiling data
/// </summary>
// Read when EnableProfiling is switched on; changing it afterwards has no
// effect on a profiling run already in progress.
public string ProfileOutputPathPrefix
{
get; set;
} = "onnxruntime_profile_"; // this is the same default in C++ implementation
/// <summary>
/// Enables profiling of InferenceSession.Run() calls. Default is false
/// </summary>
/// <value>returns _enableProfiling flag value</value>
public bool EnableProfiling
{
get
{
return _enableProfiling;
}
set
{
// Only cross into native code on an actual state transition.
// The current ProfileOutputPathPrefix is captured when profiling is enabled.
if (!_enableProfiling && value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableProfiling(handle, NativeOnnxValueHelper.GetPlatformSerializedString(ProfileOutputPathPrefix)));
_enableProfiling = true;
}
else if (_enableProfiling && !value)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableProfiling(handle));
_enableProfiling = false;
}
}
}
// Shadow of the native-side flag; must stay in sync with the native session options.
private bool _enableProfiling = false;
/// <summary>
/// Set filepath to save optimized model after graph level transformations. Default is empty, which implies saving is disabled.
/// </summary>
/// <value>returns _optimizedModelFilePath flag value</value>
public string OptimizedModelFilePath
{
get
{
return _optimizedModelFilePath;
}
set
{
// Skip the native call when the value is unchanged.
if (value != _optimizedModelFilePath)
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetOptimizedModelFilePath(handle, NativeOnnxValueHelper.GetPlatformSerializedString(value)));
_optimizedModelFilePath = value;
}
}
}
// Shadow of the native-side value; empty string means "do not save".
private string _optimizedModelFilePath = "";
/// <summary>
/// Enables Arena allocator for the CPU memory allocations. Default is true.
/// </summary>
/// <value>returns _enableCpuMemArena flag value</value>
public bool EnableCpuMemArena
{
    get { return _enableCpuMemArena; }
    set
    {
        // Assigning the current value is a no-op; only a real transition
        // needs to be propagated to the native session options.
        if (value == _enableCpuMemArena)
        {
            return;
        }
        if (value)
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtEnableCpuMemArena(handle));
        }
        else
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtDisableCpuMemArena(handle));
        }
        _enableCpuMemArena = value;
    }
}
// Shadow of the native-side flag; must stay in sync with the native session options.
private bool _enableCpuMemArena = true;
/// <summary>
/// Log Id to be used for the session. Default is empty string.
/// </summary>
/// <value>returns _logId value</value>
public string LogId
{
get
{
return _logId;
}
set
{
// Marshal the id as a pinned, NUL-terminated UTF-8 string; the cached
// value is updated only after the native call succeeds.
var logIdPinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(value), GCHandleType.Pinned);
using (var pinnedlogIdName = new PinnedGCHandle(logIdPinned))
{
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionLogId(handle, pinnedlogIdName.Pointer));
}
_logId = value;
}
}
private string _logId = "";
/// <summary>
/// Log Severity Level for the session logs. Default = ORT_LOGGING_LEVEL_WARNING
/// </summary>
/// <value>returns _logSeverityLevel value</value>
public OrtLoggingLevel LogSeverityLevel
{
get
{
return _logSeverityLevel;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionLogSeverityLevel(handle, value));
_logSeverityLevel = value;
}
}
private OrtLoggingLevel _logSeverityLevel = OrtLoggingLevel.ORT_LOGGING_LEVEL_WARNING;
/// <summary>
/// Log Verbosity Level for the session logs. Default = 0. Valid values are >=0.
/// This takes into effect only when the LogSeverityLevel is set to ORT_LOGGING_LEVEL_VERBOSE.
/// </summary>
/// <value>returns _logVerbosityLevel value</value>
public int LogVerbosityLevel
{
get
{
return _logVerbosityLevel;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionLogVerbosityLevel(handle, value));
_logVerbosityLevel = value;
}
}
private int _logVerbosityLevel = 0;
/// <summary>
/// Sets the number of threads used to parallelize the execution within nodes
/// A value of 0 means ORT will pick a default
/// </summary>
/// <value>returns _intraOpNumThreads value</value>
public int IntraOpNumThreads
{
get
{
return _intraOpNumThreads;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetIntraOpNumThreads(handle, value));
_intraOpNumThreads = value;
}
}
private int _intraOpNumThreads = 0; // set to what is set in C++ SessionOptions by default;
/// <summary>
/// Sets the number of threads used to parallelize the execution of the graph (across nodes)
/// If sequential execution is enabled this value is ignored
/// A value of 0 means ORT will pick a default
/// </summary>
/// <value>returns _interOpNumThreads value</value>
public int InterOpNumThreads
{
get
{
return _interOpNumThreads;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetInterOpNumThreads(handle, value));
_interOpNumThreads = value;
}
}
private int _interOpNumThreads = 0; // set to what is set in C++ SessionOptions by default;
/// <summary>
/// Sets the graph optimization level for the session. Default is set to ORT_ENABLE_ALL.
/// </summary>
/// <value>returns _graphOptimizationLevel value</value>
public GraphOptimizationLevel GraphOptimizationLevel
{
get
{
return _graphOptimizationLevel;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionGraphOptimizationLevel(handle, value));
_graphOptimizationLevel = value;
}
}
private GraphOptimizationLevel _graphOptimizationLevel = GraphOptimizationLevel.ORT_ENABLE_ALL;
/// <summary>
/// Sets the execution mode for the session. Default is set to ORT_SEQUENTIAL.
/// See [ONNX_Runtime_Perf_Tuning.md] for more details.
/// </summary>
/// <value>returns _executionMode value</value>
public ExecutionMode ExecutionMode
{
get
{
return _executionMode;
}
set
{
// The cached value is updated only after the native call succeeds.
NativeApiStatus.VerifySuccess(NativeMethods.OrtSetSessionExecutionMode(handle, value));
_executionMode = value;
}
}
private ExecutionMode _executionMode = ExecutionMode.ORT_SEQUENTIAL;
#endregion
#region Private Methods
#if !__MOBILE__
// Declared, but called only if OS = Windows.
// Used by the Check*ExecutionProviderDLLs helpers below to probe for delay-loaded native DLLs.
[DllImport("kernel32.dll")]
private static extern IntPtr LoadLibrary(string dllToLoad);
// Retrieves the Windows system directory path, used to build actionable error messages.
[DllImport("kernel32.dll")]
static extern uint GetSystemDirectory([Out] StringBuilder lpBuffer, uint uSize);
#else
// Mobile stubs: the DLL checks are never expected to execute on mobile platforms.
private static IntPtr LoadLibrary(string dllToLoad)
{
throw new NotSupportedException();
}
static uint GetSystemDirectory([Out] StringBuilder lpBuffer, uint uSize)
{
throw new NotSupportedException();
}
#endif
// Verifies (on Windows only) that every CUDA DLL onnxruntime delay-loads can be resolved,
// turning an otherwise cryptic native loader failure into an actionable exception.
// Returns true when all libraries load (or on non-Windows platforms, where there is nothing to check).
private static bool CheckCudaExecutionProviderDLLs()
{
    if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        foreach (var dll in cudaDelayLoadedLibs)
        {
            IntPtr handle = LoadLibrary(dll);
            if (handle != IntPtr.Zero)
                continue;
            var sysdir = new StringBuilder(String.Empty, 2048);
            GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
            // Fix: the message used to render "... execution. . Verify ..." because the
            // second fragment began with a duplicated ". ".
            throw new OnnxRuntimeException(
                ErrorCode.NoSuchFile,
                $"kernel32.LoadLibrary():'{dll}' not found. CUDA is required for GPU execution. " +
                $"Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
            );
        }
    }
    return true;
}
// Verifies (on Windows only) that every TensorRT/CUDA DLL onnxruntime delay-loads can be
// resolved, turning an otherwise cryptic native loader failure into an actionable exception.
// Returns true when all libraries load (or on non-Windows platforms, where there is nothing to check).
private static bool CheckTensorrtExecutionProviderDLLs()
{
    if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
    {
        foreach (var dll in trtDelayLoadedLibs)
        {
            IntPtr handle = LoadLibrary(dll);
            if (handle != IntPtr.Zero)
                continue;
            var sysdir = new StringBuilder(String.Empty, 2048);
            GetSystemDirectory(sysdir, (uint)sysdir.Capacity);
            // Fix: the message used to render "... execution. . Verify ..." because the
            // second fragment began with a duplicated ". ".
            throw new OnnxRuntimeException(
                ErrorCode.NoSuchFile,
                $"kernel32.LoadLibrary():'{dll}' not found. TensorRT/CUDA are required for GPU execution. " +
                $"Verify it is available in the system directory={sysdir}. Else copy it to the output folder."
            );
        }
    }
    return true;
}
#endregion
#region SafeHandle
/// <summary>
/// Overrides SafeHandle.ReleaseHandle() to properly dispose of
/// the native instance of SessionOptions
/// </summary>
/// <returns>always returns true</returns>
protected override bool ReleaseHandle()
{
// Called at most once by the SafeHandle infrastructure; zeroing the
// field afterwards makes IsInvalid report true.
NativeMethods.OrtReleaseSessionOptions(handle);
handle = IntPtr.Zero;
return true;
}
#endregion
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
namespace Microsoft.ML.OnnxRuntime
{
/// <summary>
/// Helper to allow the creation/addition of session options based on pre-defined named entries.
/// </summary>
/// <summary>
/// Helper to allow the creation/addition of session options based on pre-defined named entries.
/// </summary>
public static class SessionOptionsContainer
{
    static Lazy<Action<SessionOptions>> _defaultHandler;

    static readonly Dictionary<string, Lazy<Action<SessionOptions>>> _configurationHandlers =
        new Dictionary<string, Lazy<Action<SessionOptions>>>();

    // Lazily materializes a no-op default handler the first time it is requested,
    // unless Register(Action<SessionOptions>) installed one earlier.
    static Lazy<Action<SessionOptions>> DefaultHandler
    {
        get
        {
            if (_defaultHandler == null)
            {
                _defaultHandler = new Lazy<Action<SessionOptions>>(() => (options) => { /* use as is */ });
            }
            return _defaultHandler;
        }
    }

    /// <summary>
    /// Register the default handler. This is used when a configuration name is not provided.
    /// </summary>
    /// <param name="defaultHandler">Handler that applies the default settings to a SessionOptions instance.</param>
    public static void Register(Action<SessionOptions> defaultHandler)
    {
        _defaultHandler = new Lazy<Action<SessionOptions>>(() => defaultHandler);
    }

    /// <summary>
    /// Register a named handler.
    /// </summary>
    /// <param name="configuration">Configuration name.</param>
    /// <param name="handler">Handler that applies the settings for the configuration to a SessionOptions instance.</param>
    public static void Register(string configuration, Action<SessionOptions> handler)
    {
        _configurationHandlers[configuration] = new Lazy<Action<SessionOptions>>(() => handler);
    }

    /// <summary>
    /// Create a SessionOptions instance with configuration applied.
    /// </summary>
    /// <param name="configuration">Configuration to use. If not provided, the default set of
    /// session options will be applied if useDefaultAsFallback is true.</param>
    /// <param name="useDefaultAsFallback">If configuration is not provided or not found, use the default session options.</param>
    /// <returns>SessionOptions with configuration applied.</returns>
    public static SessionOptions Create(string configuration = null, bool useDefaultAsFallback = true)
    {
        return new SessionOptions().ApplyConfiguration(configuration, useDefaultAsFallback);
    }

    /// <summary>
    /// Reset by removing all registered handlers.
    /// </summary>
    public static void Reset()
    {
        _defaultHandler = null;
        _configurationHandlers.Clear();
    }

    /// <summary>
    /// Apply a configuration to a SessionOptions instance.
    /// </summary>
    /// <param name="options">SessionOptions to apply configuration to.</param>
    /// <param name="configuration">Configuration to use.</param>
    /// <param name="useDefaultAsFallback">Use the default configuration if 'configuration' is not provided or not found.</param>
    /// <returns>Updated SessionOptions instance.</returns>
    public static SessionOptions ApplyConfiguration(this SessionOptions options, string configuration = null,
                                                    bool useDefaultAsFallback = true)
    {
        Resolve(configuration, useDefaultAsFallback)(options);
        return options;
    }

    // Maps a configuration name to its handler: a registered named handler wins,
    // otherwise the default applies (always for an unnamed request, and for an
    // unknown name only when fallback is allowed).
    static Action<SessionOptions> Resolve(string configuration = null, bool useDefaultAsFallback = true)
    {
        if (!string.IsNullOrWhiteSpace(configuration) &&
            _configurationHandlers.TryGetValue(configuration, out var handler))
        {
            return handler.Value;
        }
        if (string.IsNullOrWhiteSpace(configuration) || useDefaultAsFallback)
        {
            return DefaultHandler.Value;
        }
        throw new KeyNotFoundException($"Configuration not found for '{configuration}'");
    }
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayTensorExtensions.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
namespace Microsoft.ML.OnnxRuntime.Tensors
{
/// <summary>
/// A static class that houses static DenseTensor<T> extension methods
/// </summary>
/// <summary>
/// A static class that houses static DenseTensor<T> extension methods
/// </summary>
public static class ArrayTensorExtensions
{
    /// <summary>
    /// Creates a copy of this single-dimensional array as a DenseTensor&lt;T&gt;
    /// </summary>
    /// <typeparam name="T">Type contained in the array to copy to the DenseTensor&lt;T&gt;.</typeparam>
    /// <param name="array">The array to create a DenseTensor&lt;T&gt; from.</param>
    /// <returns>A 1-dimensional DenseTensor&lt;T&gt; with the same length and content as <paramref name="array"/>.</returns>
    public static DenseTensor<T> ToTensor<T>(this T[] array)
    {
        // Copy here and hand DenseTensor a Memory<T>; the DenseTensor(Array, ...) ctor is slower.
        var flat = new T[array.Length];
        array.CopyTo(flat, 0);
        return new DenseTensor<T>(new Memory<T>(flat), new int[] { array.Length });
    }

    /// <summary>
    /// Creates a copy of this two-dimensional array as a DenseTensor&lt;T&gt;
    /// </summary>
    /// <typeparam name="T">Type contained in the array to copy to the DenseTensor&lt;T&gt;.</typeparam>
    /// <param name="array">The array to create a DenseTensor&lt;T&gt; from.</param>
    /// <param name="reverseStride">False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): row-major. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): column-major.</param>
    /// <returns>A 2-dimensional DenseTensor&lt;T&gt; with the same dimensions and content as <paramref name="array"/>.</returns>
    public static DenseTensor<T> ToTensor<T>(this T[,] array, bool reverseStride = false)
    {
        if (reverseStride)
        {
            // Column-major layout needs the DenseTensor ctor's copy logic.
            return new DenseTensor<T>(array, reverseStride);
        }
        // Row-major: flatten to a 1D buffer and wrap it as Memory<T>, which is cheaper.
        var dims = new int[] { array.GetLength(0), array.GetLength(1) };
        var flat = new T[array.Length];
        long next = 0;
        foreach (var element in array)
        {
            flat[next++] = element;
        }
        return new DenseTensor<T>(new Memory<T>(flat), dims);
    }

    /// <summary>
    /// Creates a copy of this three-dimensional array as a DenseTensor&lt;T&gt;
    /// </summary>
    /// <typeparam name="T">Type contained in the array to copy to the DenseTensor&lt;T&gt;.</typeparam>
    /// <param name="array">The array to create a DenseTensor&lt;T&gt; from.</param>
    /// <param name="reverseStride">False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor.</param>
    /// <returns>A 3-dimensional DenseTensor&lt;T&gt; with the same dimensions and content as <paramref name="array"/>.</returns>
    public static DenseTensor<T> ToTensor<T>(this T[,,] array, bool reverseStride = false)
    {
        if (reverseStride)
        {
            // Column-major layout needs the DenseTensor ctor's copy logic.
            return new DenseTensor<T>(array, reverseStride);
        }
        // Row-major: flatten to a 1D buffer and wrap it as Memory<T>, which is cheaper.
        var dims = new int[] { array.GetLength(0), array.GetLength(1), array.GetLength(2) };
        var flat = new T[array.Length];
        long next = 0;
        foreach (var element in array)
        {
            flat[next++] = element;
        }
        return new DenseTensor<T>(new Memory<T>(flat), dims);
    }

    /// <summary>
    /// Creates a copy of this four-dimensional array as a DenseTensor&lt;T&gt;
    /// </summary>
    /// <typeparam name="T">Type contained in the array to copy to the DenseTensor&lt;T&gt;.</typeparam>
    /// <param name="array">The array to create a DenseTensor&lt;T&gt; from.</param>
    /// <param name="reverseStride">False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor.</param>
    /// <returns>A 4-dimensional DenseTensor&lt;T&gt; with the same dimensions and content as <paramref name="array"/>.</returns>
    public static DenseTensor<T> ToTensor<T>(this T[,,,] array, bool reverseStride = false)
    {
        if (reverseStride)
        {
            // Column-major layout needs the DenseTensor ctor's copy logic.
            return new DenseTensor<T>(array, reverseStride);
        }
        // Row-major: flatten to a 1D buffer and wrap it as Memory<T>, which is cheaper.
        var dims = new int[] {
            array.GetLength(0), array.GetLength(1), array.GetLength(2), array.GetLength(3) };
        var flat = new T[array.Length];
        long next = 0;
        foreach (var element in array)
        {
            flat[next++] = element;
        }
        return new DenseTensor<T>(new Memory<T>(flat), dims);
    }

    /// <summary>
    /// Creates a copy of this n-dimensional array as a DenseTensor&lt;T&gt;
    /// </summary>
    /// <typeparam name="T">Type contained in the array to copy to the DenseTensor&lt;T&gt;.</typeparam>
    /// <param name="array">The array to create a DenseTensor&lt;T&gt; from.</param>
    /// <param name="reverseStride">False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor.</param>
    /// <returns>A n-dimensional DenseTensor&lt;T&gt; with the same dimensions and content as <paramref name="array"/>.</returns>
    public static DenseTensor<T> ToTensor<T>(this Array array, bool reverseStride = false)
    {
        // Arbitrary rank: delegate the copy entirely to the DenseTensor ctor.
        return new DenseTensor<T>(array, reverseStride);
    }
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/ArrayUtilities.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Diagnostics;
using System;
namespace Microsoft.ML.OnnxRuntime.Tensors
{
internal static class ArrayUtilities
{
public const int StackallocMax = 16;
/// <summary>
/// Computes the number of elements implied by a shape (the product of all dimensions).
/// An empty shape yields 1 (scalar).
/// </summary>
/// <param name="shape">dimension extents; every element must be non-negative</param>
/// <returns>product of all dimensions</returns>
/// <exception cref="ArgumentOutOfRangeException">when any dimension is negative</exception>
/// <exception cref="OverflowException">when the product exceeds Int64.MaxValue</exception>
public static long GetSizeForShape(long[] shape)
{
    long product = 1;
    foreach (var dim in shape)
    {
        if (dim < 0)
        {
            // Fix: the message text was previously passed as the paramName argument,
            // producing a garbled "Parameter 'Shape must not...'" exception message.
            throw new ArgumentOutOfRangeException(nameof(shape),
                "Shape must not have negative elements: " + dim);
        }
        // Checked multiplication for consistency with GetProduct below, so an
        // overflow surfaces as OverflowException rather than a wrapped count.
        checked
        {
            product *= dim;
        }
    }
    return product;
}
/// <summary>
/// Multiplies the dimensions from <paramref name="startIndex"/> onward into a total element count.
/// Returns 1 when the range is empty. Negative dimensions raise ArgumentOutOfRangeException.
/// </summary>
public static long GetProduct(ReadOnlySpan<int> dimensions, int startIndex = 0)
{
    long total = 1;
    for (int pos = startIndex; pos < dimensions.Length; pos++)
    {
        int extent = dimensions[pos];
        if (extent < 0)
        {
            throw new ArgumentOutOfRangeException($"{nameof(dimensions)}[{pos}]");
        }
        // A long is far larger than any product used here, but keep the
        // multiplication checked so overflow is never silent.
        checked
        {
            total *= extent;
        }
    }
    return total;
}
/// <summary>
/// True when every element is greater than or equal to its predecessor
/// (non-strictly ascending); trivially true for empty or single-element input.
/// </summary>
public static bool IsAscending(ReadOnlySpan<int> values)
{
    // Scan adjacent pairs from the end; direction does not affect the result.
    for (int pos = values.Length - 1; pos > 0; pos--)
    {
        if (values[pos] < values[pos - 1])
        {
            return false;
        }
    }
    return true;
}
/// <summary>
/// True when every element is less than or equal to its predecessor
/// (non-strictly descending); trivially true for empty or single-element input.
/// </summary>
public static bool IsDescending(ReadOnlySpan<int> values)
{
    // Scan adjacent pairs from the end; direction does not affect the result.
    for (int pos = values.Length - 1; pos > 0; pos--)
    {
        if (values[pos] > values[pos - 1])
        {
            return false;
        }
    }
    return true;
}
/// <summary>
/// Gets the set of strides that can be used to calculate the offset of n-dimensions in a 1-dimensional layout
/// </summary>
/// <param name="dimensions">extent of each dimension</param>
/// <param name="reverseStride">false for row-major (last dimension contiguous), true for column-major (first dimension contiguous)</param>
/// <returns>one stride per dimension; empty for a scalar (rank 0)</returns>
public static int[] GetStrides(ReadOnlySpan<int> dimensions, bool reverseStride = false)
{
    var strides = new int[dimensions.Length];
    if (strides.Length == 0)
    {
        // Scalar: no dimensions, no strides.
        return strides;
    }

    // Walk the dimensions starting from the contiguous end, accumulating the
    // running product of extents as the stride of each successive dimension.
    int acc = 1;
    if (reverseStride)
    {
        for (int d = 0; d < dimensions.Length; d++)
        {
            strides[d] = acc;
            acc *= dimensions[d];
        }
    }
    else
    {
        for (int d = dimensions.Length - 1; d >= 0; d--)
        {
            strides[d] = acc;
            acc *= dimensions[d];
        }
    }
    return strides;
}
/// <summary>
/// Partitions <paramref name="strides"/> into two output arrays: strides whose axis appears in
/// <paramref name="splitAxes"/> land in <paramref name="splitStrides"/> (at the position of the
/// matching axis entry), all remaining strides are appended in order into
/// <paramref name="newStrides"/>, starting at the given offsets.
/// </summary>
public static void SplitStrides(int[] strides, int[] splitAxes, int[] newStrides, int stridesOffset, int[] splitStrides, int splitStridesOffset)
{
    int keepCount = 0;
    for (int axis = 0; axis < strides.Length; axis++)
    {
        // First matching entry wins, mirroring a linear scan with early exit.
        int match = Array.IndexOf(splitAxes, axis);
        if (match >= 0)
        {
            splitStrides[splitStridesOffset + match] = strides[axis];
        }
        else
        {
            newStrides[stridesOffset + keepCount++] = strides[axis];
        }
    }
}
/// <summary>
/// Calculates the 1-d index for n-d indices in layout specified by strides.
/// </summary>
/// <param name="strides">per-dimension strides</param>
/// <param name="indices">n-d coordinates</param>
/// <param name="startFromDimension">first dimension included in the dot-product</param>
/// <returns>linear offset into the backing buffer</returns>
public static int GetIndex(int[] strides, ReadOnlySpan<int> indices, int startFromDimension = 0)
{
    Debug.Assert(strides.Length == indices.Length);

    // Dot-product of coordinates and strides, accumulated from the last dimension backwards.
    int offset = 0;
    for (int d = indices.Length - 1; d >= startFromDimension; d--)
    {
        offset += strides[d] * indices[d];
    }
    return offset;
}
/// <summary>
/// Calculates the n-d indices from the 1-d index in a layout specified by strides.
/// </summary>
/// <param name="strides">per-dimension strides, ordered according to reverseStride</param>
/// <param name="reverseStride">true when strides ascend (reversed layout)</param>
/// <param name="index">linear index to decompose</param>
/// <param name="indices">destination for the computed coordinates</param>
/// <param name="startFromDimension">first dimension to fill</param>
public static void GetIndices(ReadOnlySpan<int> strides, bool reverseStride, int index, int[] indices, int startFromDimension = 0)
{
    Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides");
    Debug.Assert(strides.Length == indices.Length);

    // scalar tensor - nothing to process
    if (indices.Length == 0)
    {
        return;
    }

    var left = index;
    for (int i = startFromDimension; i < strides.Length; i++)
    {
        // visit the largest stride first so each division peels off one coordinate
        var dim = reverseStride ? strides.Length - 1 - i : i;
        var coordinate = left / strides[dim];
        indices[dim] = coordinate;
        left -= coordinate * strides[dim];
    }
}
/// <summary>
/// Calculates the n-d indices from the 1-d index in a layout specified by strides.
/// </summary>
/// <param name="strides">per-dimension strides, ordered according to reverseStride</param>
/// <param name="reverseStride">true when strides ascend (reversed layout)</param>
/// <param name="index">linear index to decompose</param>
/// <param name="indices">destination span for the computed coordinates</param>
/// <param name="startFromDimension">first dimension to fill</param>
public static void GetIndices(ReadOnlySpan<int> strides, bool reverseStride, int index, Span<int> indices, int startFromDimension = 0)
{
    Debug.Assert(reverseStride ? IsAscending(strides) : IsDescending(strides), "Index decomposition requires ordered strides");
    Debug.Assert(strides.Length == indices.Length);

    // scalar tensor - nothing to process
    if (indices.Length == 0)
    {
        return;
    }

    var rem = index;
    for (int pos = startFromDimension; pos < strides.Length; pos++)
    {
        // divide by the largest stride first so each step isolates one coordinate
        var axis = reverseStride ? strides.Length - 1 - pos : pos;
        var coord = rem / strides[axis];
        indices[axis] = coord;
        rem -= coord * strides[axis];
    }
}
/// <summary>
/// Takes an 1-d index over n-d sourceStrides and recalculates it assuming same n-d coordinates over a different n-d strides
/// </summary>
/// <param name="index">linear index under the source layout (must be non-negative)</param>
/// <param name="sourceStrides">strides of the source layout</param>
/// <param name="sourceReverseStride">true when sourceStrides ascend (reversed layout)</param>
/// <param name="transformStrides">strides of the destination layout</param>
/// <returns>linear index under the destination layout</returns>
public static int TransformIndexByStrides(int index, int[] sourceStrides, bool sourceReverseStride, int[] transformStrides)
{
    Debug.Assert(index >= 0);
    Debug.Assert(sourceReverseStride ? IsAscending(sourceStrides) : IsDescending(sourceStrides), "Index decomposition requires ordered strides");
    Debug.Assert(sourceStrides.Length == transformStrides.Length);

    // scalar tensor maps to itself
    if (sourceStrides.Length == 0)
    {
        Debug.Assert(index == 0, "Index has to be zero for a scalar tensor");
        return 0;
    }

    int mapped = 0;
    int left = index;
    for (int pos = 0; pos < sourceStrides.Length; pos++)
    {
        // consume the largest source stride first to recover each coordinate
        int dim = sourceReverseStride ? sourceStrides.Length - 1 - pos : pos;
        int coordinate = left / sourceStrides[dim];
        left -= coordinate * sourceStrides[dim];
        mapped += coordinate * transformStrides[dim];
    }
    return mapped;
}
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/DenseTensor.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.InteropServices;
using System;
namespace Microsoft.ML.OnnxRuntime.Tensors
{
/// <summary>
/// Represents a multi-dimensional collection of objects of type T that can be accessed by indices.
/// DenseTensor stores values in a contiguous sequential block of memory where all values are represented.
/// </summary>
/// <typeparam name="T">
/// Type contained within the Tensor. Typically a value type such as int, double, float, etc.
/// </typeparam>
public class DenseTensor<T> : Tensor<T>
{
    private readonly Memory<T> memory;

    /// <summary>
    /// Initializes a tensor with the same dimensions as <paramref name="fromArray"/>, copying its contents.
    /// </summary>
    /// <param name="fromArray">source array whose values are copied</param>
    /// <param name="reverseStride">true to lay the copy out with reversed (column-major-like) strides</param>
    internal DenseTensor(Array fromArray, bool reverseStride = false) : base(fromArray, reverseStride)
    {
        // copy initial array
        var backingArray = new T[fromArray.Length];

        int index = 0;
        if (reverseStride)
        {
            // Array is always row-major, so each flat source index must be remapped
            // into the reversed destination layout.
            var sourceStrides = ArrayUtilities.GetStrides(dimensions);

            foreach (var item in fromArray)
            {
                var destIndex = ArrayUtilities.TransformIndexByStrides(index++, sourceStrides, false, strides);
                backingArray[destIndex] = (T)item;
            }
        }
        else
        {
            foreach (var item in fromArray)
            {
                backingArray[index++] = (T)item;
            }
        }
        memory = backingArray;
    }

    /// <summary>
    /// Initializes a rank-1 Tensor using the specified <paramref name="length"/>.
    /// </summary>
    /// <param name="length">Size of the 1-dimensional tensor</param>
    public DenseTensor(int length) : base(length)
    {
        memory = new T[length];
    }

    /// <summary>
    /// Initializes a rank-n Tensor using the dimensions specified in <paramref name="dimensions"/>.
    /// </summary>
    /// <param name="dimensions">
    /// An span of integers that represent the size of each dimension of the DenseTensor to create.
    /// </param>
    /// <param name="reverseStride">
    /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension
    /// is most minor (closest together): akin to row-major in a rank-2 tensor.
    /// True to indicate that the last dimension is most major (farthest apart) and the first dimension is most
    /// minor (closest together): akin to column-major in a rank-2 tensor.
    /// </param>
    public DenseTensor(ReadOnlySpan<int> dimensions, bool reverseStride = false) : base(dimensions, reverseStride)
    {
        memory = new T[Length];
    }

    /// <summary>
    /// Constructs a new DenseTensor of the specified dimensions, wrapping existing backing memory for the contents.
    /// </summary>
    /// <param name="memory">backing memory; its length must equal the product of dimensions</param>
    /// <param name="dimensions">
    /// An span of integers that represent the size of each dimension of the DenseTensor to create.</param>
    /// <param name="reverseStride">
    /// False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension
    /// is most minor (closest together): akin to row-major in a rank-2 tensor.
    /// True to indicate that the last dimension is most major (farthest apart) and the first dimension is most
    /// minor (closest together): akin to column-major in a rank-2 tensor.
    /// </param>
    public DenseTensor(Memory<T> memory, ReadOnlySpan<int> dimensions, bool reverseStride = false)
        : base(dimensions, reverseStride)
    {
        // Validate before storing so an invalid instance is never observable.
        if (Length != memory.Length)
        {
            throw new ArgumentException(
                $"Length of {nameof(memory)} ({memory.Length}) must match product of " +
                $"{nameof(dimensions)} ({Length}).");
        }
        this.memory = memory;
    }

    /// <summary>
    /// Memory storing backing values of this tensor.
    /// </summary>
    public Memory<T> Buffer => memory;

    /// <summary>
    /// Gets the value at the specified index, where index is a linearized version of n-dimension indices
    /// using strides. For a scalar, use index = 0
    /// </summary>
    /// <param name="index">An integer index computed as a dot-product of indices.</param>
    /// <returns>The value at the specified position in this Tensor.</returns>
    public override T GetValue(int index)
    {
        return Buffer.Span[index];
    }

    /// <summary>
    /// Sets the value at the specified index, where index is a linearized version of n-dimension indices
    /// using strides. For a scalar, use index = 0
    /// </summary>
    /// <param name="index">An integer index computed as a dot-product of indices.</param>
    /// <param name="value">The new value to set at the specified position in this Tensor.</param>
    public override void SetValue(int index, T value)
    {
        Buffer.Span[index] = value;
    }

    /// <summary>
    /// Overrides Tensor.CopyTo(). Copies the content of the Tensor
    /// to the specified array starting with arrayIndex
    /// </summary>
    /// <param name="array">destination array</param>
    /// <param name="arrayIndex">start index</param>
    protected override void CopyTo(T[] array, int arrayIndex)
    {
        if (array == null)
        {
            throw new ArgumentNullException(nameof(array));
        }
        if (array.Length < arrayIndex + Length)
        {
            throw new ArgumentException(
                "The number of elements in the Tensor is greater than the available space from index to " +
                "the end of the destination array.", nameof(array));
        }
        Buffer.Span.CopyTo(array.AsSpan(arrayIndex));
    }

    /// <summary>
    /// Determines the index of a specific item in the Tensor&lt;T&gt;.
    /// </summary>
    /// <param name="item">Object to locate</param>
    /// <returns>The index of item if found in the tensor; otherwise, -1</returns>
    protected override int IndexOf(T item)
    {
        // TODO: use Span.IndexOf when/if it removes the IEquatable type constraint
        if (MemoryMarshal.TryGetArray<T>(Buffer, out var arraySegment))
        {
            var result = Array.IndexOf(arraySegment.Array, item, arraySegment.Offset, arraySegment.Count);
            if (result != -1)
            {
                // Array.IndexOf reports positions relative to the whole array;
                // convert back to a tensor-relative index.
                result -= arraySegment.Offset;
            }
            return result;
        }
        else
        {
            return base.IndexOf(item);
        }
    }

    /// <summary>
    /// Creates a copy of this tensor with its own copy of the backing storage.
    /// </summary>
    /// <returns>A copy of this tensor.</returns>
    public override Tensor<T> Clone()
    {
        // create copy
        return new DenseTensor<T>(new Memory<T>(memory.ToArray()), dimensions, IsReversedStride);
    }

    /// <summary>
    /// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor
    /// with elements initialized to their default value.
    /// </summary>
    /// <typeparam name="TResult">Type contained in the returned Tensor.</typeparam>
    /// <param name="dimensions">
    /// An span of integers that represent the size of each dimension of the DenseTensor to create.</param>
    /// <returns>A new tensor with the same layout as this tensor but different type and dimensions.</returns>
    public override Tensor<TResult> CloneEmpty<TResult>(ReadOnlySpan<int> dimensions)
    {
        return new DenseTensor<TResult>(dimensions, IsReversedStride);
    }

    /// <summary>
    /// Reshapes the current tensor to new dimensions, using the same backing storage.
    /// </summary>
    /// <param name="dimensions">
    /// An span of integers that represent the size of each dimension of the DenseTensor to create.</param>
    /// <returns>A new tensor that reinterprets backing Buffer of this tensor with different dimensions.</returns>
    public override Tensor<T> Reshape(ReadOnlySpan<int> dimensions)
    {
        var newSize = ArrayUtilities.GetProduct(dimensions);
        if (newSize != Length)
        {
            // BUGFIX: second literal previously lacked the '$' prefix, so the
            // message printed "{Length}"/"{newSize}" verbatim instead of the values.
            throw new ArgumentException($"Cannot reshape array due to mismatch in lengths, " +
                $"currently {Length} would become {newSize}.", nameof(dimensions));
        }
        return new DenseTensor<T>(Buffer, dimensions, IsReversedStride);
    }
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
// This file is copied and adapted from the following git repository -
// https://github.com/dotnet/corefx
// Commit ID: bdd0814360d4c3a58860919f292a306242f27da1
// Path: /src/System.Numerics.Tensors/src/System/Numerics/Tensors/Tensor.cs
// Original license statement below -
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Text;
namespace Microsoft.ML.OnnxRuntime.Tensors
{
/// <summary>
/// Supported Tensor DataType
/// </summary>
public enum TensorElementType
{
    Float = 1,        // maps to System.Single (see TensorBase.typeInfoMap)
    UInt8 = 2,        // maps to System.Byte
    Int8 = 3,         // maps to System.SByte
    UInt16 = 4,       // maps to System.UInt16
    Int16 = 5,        // maps to System.Int16
    Int32 = 6,        // maps to System.Int32
    Int64 = 7,        // maps to System.Int64
    String = 8,       // maps to System.String; registered with size -1 (variable)
    Bool = 9,         // maps to System.Boolean
    Float16 = 10,     // maps to the Float16 struct (16-bit payload)
    Double = 11,      // maps to System.Double
    UInt32 = 12,      // maps to System.UInt32
    UInt64 = 13,      // maps to System.UInt64
    Complex64 = 14,   // no managed mapping registered in TensorBase.typeInfoMap
    Complex128 = 15,  // no managed mapping registered in TensorBase.typeInfoMap
    BFloat16 = 16,    // maps to the BFloat16 struct (16-bit payload)
    DataTypeMax = 17  // presumably a one-past-the-end sentinel, not a real element type
}
/// <summary>
/// This value type represents A Float16 value
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
public struct Float16
{
    /// <summary>
    /// float16 representation bits
    /// </summary>
    public ushort value;

    /// <summary>
    /// Constructs a Float16 from its raw 16-bit representation.
    /// </summary>
    /// <param name="v">raw bits</param>
    public Float16(ushort v)
    {
        value = v;
    }

    /// <summary>
    /// Extracts the raw 16-bit representation.
    /// </summary>
    /// <param name="f">instance of Float16</param>
    /// <returns>value member</returns>
    public static implicit operator ushort(Float16 f) => f.value;

    /// <summary>
    /// Wraps a 16-bit unsigned integer as a Float16.
    /// </summary>
    /// <param name="value">A 16-bit unsigned integer.</param>
    /// <returns>A Float16 that represents the converted 16-bit unsigned integer.</returns>
    public static implicit operator Float16(ushort value) => new Float16(value);

    /// <summary>
    /// Compares values of two Float16 for binary equality
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>result of value comparisons</returns>
    public static bool operator ==(Float16 lhs, Float16 rhs) => lhs.value == rhs.value;

    /// <summary>
    /// Compares values of two Float16 for binary inequality
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>result of value comparisons</returns>
    public static bool operator !=(Float16 lhs, Float16 rhs) => !(lhs == rhs);

    /// <summary>
    /// Returns a value indicating whether this instance and other Float16 represent the same value.
    /// </summary>
    /// <param name="other">A Float16 object to compare to this instance.</param>
    /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
    public bool Equals(Float16 other) => value == other.value;

    /// <summary>
    /// Returns a value indicating whether this instance and a specified System.Object
    /// represent the same type and value.
    /// </summary>
    /// <param name="obj">An System.Object.</param>
    /// <returns>true if obj is Float16 and its value is equal to this instance; otherwise, false.</returns>
    public override bool Equals(object obj) => obj is Float16 other && other.value == value;

    /// <summary>
    /// Returns the hash code for this instance.
    /// </summary>
    /// <returns>A 32-bit signed integer hash code.</returns>
    public override int GetHashCode() => value.GetHashCode();
}
/// <summary>
/// This value type represents A BFloat16 value
/// it is blittable as defined in https://docs.microsoft.com/en-us/dotnet/framework/interop/blittable-and-non-blittable-types
/// and as such, represented the same way in managed and native memories. This means that arrays of this type
/// do not have to be copied to be passed to native memory but simply pinned and read by native code. Thus,
/// one can create a Tensor on top of an array of these structures and feed it directly to Onnxruntime library.
/// Binary wise, it is the same as ushort[] (uint16_t in C++). However, we would like a separate type for type dispatching.
/// </summary>
public struct BFloat16
{
    /// <summary>
    /// bfloat16 representation bits
    /// </summary>
    public ushort value;

    /// <summary>
    /// Constructs a BFloat16 from its raw 16-bit representation.
    /// </summary>
    /// <param name="v">raw bits</param>
    public BFloat16(ushort v)
    {
        value = v;
    }

    /// <summary>
    /// Extracts the raw 16-bit representation.
    /// </summary>
    /// <param name="bf">instance of BFloat16</param>
    /// <returns>value member</returns>
    public static implicit operator ushort(BFloat16 bf) => bf.value;

    /// <summary>
    /// Wraps a 16-bit unsigned integer as a BFloat16.
    /// </summary>
    /// <param name="value">A 16-bit unsigned integer.</param>
    /// <returns>A BFloat16 that represents the converted 16-bit unsigned integer.</returns>
    public static implicit operator BFloat16(ushort value) => new BFloat16(value);

    /// <summary>
    /// Compares values of two BFloat16 for binary equality
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>result of value comparisons</returns>
    public static bool operator ==(BFloat16 lhs, BFloat16 rhs) => lhs.value == rhs.value;

    /// <summary>
    /// Compares values of two BFloat16 for binary inequality
    /// </summary>
    /// <param name="lhs">left operand</param>
    /// <param name="rhs">right operand</param>
    /// <returns>result of value comparisons</returns>
    public static bool operator !=(BFloat16 lhs, BFloat16 rhs) => !(lhs == rhs);

    /// <summary>
    /// Returns a value indicating whether this instance and other BFloat16 represent the same value.
    /// </summary>
    /// <param name="other">A BFloat16 object to compare to this instance.</param>
    /// <returns>true if other.value is equal to this instance; otherwise, false.</returns>
    public bool Equals(BFloat16 other) => value == other.value;

    /// <summary>
    /// Returns a value indicating whether this instance and a specified System.Object
    /// represent the same type and value.
    /// </summary>
    /// <param name="obj">An System.Object.</param>
    /// <returns>true if obj is BFloat16 and its value is equal to this instance; otherwise, false.</returns>
    public override bool Equals(object obj) => obj is BFloat16 other && other.value == value;

    /// <summary>
    /// Returns the hash code for this instance.
    /// </summary>
    /// <returns>A 32-bit signed integer hash code.</returns>
    public override int GetHashCode() => value.GetHashCode();
}
/// <summary>
/// Helps typecasting. Holds Tensor element type traits.
/// </summary>
public class TensorTypeInfo
{
    /// <summary>
    /// TensorElementType enum
    /// </summary>
    /// <value>type enum value</value>
    public TensorElementType ElementType { get; private set; }

    /// <summary>
    /// Size of the stored primitive type in bytes (-1 for string)
    /// </summary>
    /// <value>size in bytes</value>
    public int TypeSize { get; private set; }

    /// <summary>
    /// Is the type is a string
    /// </summary>
    /// <value>true if Tensor element type is a string</value>
    public bool IsString => ElementType == TensorElementType.String;

    /// <summary>
    /// Ctor
    /// </summary>
    /// <param name="elementType">TensorElementType value</param>
    /// <param name="typeSize">size of the type in bytes</param>
    public TensorTypeInfo(TensorElementType elementType, int typeSize)
    {
        TypeSize = typeSize;
        ElementType = elementType;
    }
}
/// <summary>
/// Holds TensorElement traits
/// </summary>
public class TensorElementTypeInfo
{
    /// <summary>
    /// Tensor element type
    /// </summary>
    /// <value>System.Type</value>
    public Type TensorType { get; private set; }

    /// <summary>
    /// Size of the stored primitive type in bytes
    /// </summary>
    /// <value>size in bytes</value>
    public int TypeSize { get; private set; }

    /// <summary>
    /// Is the type is a string
    /// </summary>
    /// <value>true if Tensor element type is a string</value>
    public bool IsString { get; private set; }

    /// <summary>
    /// Ctor
    /// </summary>
    /// <param name="type">Tensor element type</param>
    /// <param name="typeSize">typesize</param>
    public TensorElementTypeInfo(Type type, int typeSize)
    {
        TypeSize = typeSize;
        TensorType = type;
        // cached so callers avoid repeated typeof(string) comparisons
        IsString = typeof(string) == type;
    }
}
/// <summary>
/// This class is a base for all Tensors. It hosts maps with type traits.
/// </summary>
public class TensorBase
{
    // Forward lookup: managed element type -> traits (enum value + byte size).
    private static readonly Dictionary<Type, TensorTypeInfo> typeInfoMap;
    // Reverse lookup: element type enum -> managed type traits.
    private static readonly Dictionary<TensorElementType, TensorElementTypeInfo> tensorElementTypeInfoMap;

    static TensorBase()
    {
        typeInfoMap = new Dictionary<Type, TensorTypeInfo>()
        {
            { typeof(float), new TensorTypeInfo(TensorElementType.Float, sizeof(float)) },
            { typeof(byte), new TensorTypeInfo(TensorElementType.UInt8, sizeof(byte)) },
            { typeof(sbyte), new TensorTypeInfo(TensorElementType.Int8, sizeof(sbyte)) },
            { typeof(ushort), new TensorTypeInfo(TensorElementType.UInt16, sizeof(ushort)) },
            { typeof(short), new TensorTypeInfo(TensorElementType.Int16, sizeof(short)) },
            { typeof(int), new TensorTypeInfo(TensorElementType.Int32, sizeof(int)) },
            { typeof(long), new TensorTypeInfo(TensorElementType.Int64, sizeof(long)) },
            { typeof(string), new TensorTypeInfo(TensorElementType.String, -1) },
            { typeof(bool), new TensorTypeInfo(TensorElementType.Bool, sizeof(bool)) },
            { typeof(Float16), new TensorTypeInfo(TensorElementType.Float16, sizeof(ushort)) },
            { typeof(double), new TensorTypeInfo(TensorElementType.Double, sizeof(double)) },
            { typeof(uint), new TensorTypeInfo(TensorElementType.UInt32, sizeof(uint)) },
            { typeof(ulong), new TensorTypeInfo(TensorElementType.UInt64, sizeof(ulong)) },
            { typeof(BFloat16), new TensorTypeInfo(TensorElementType.BFloat16, sizeof(ushort)) }
        };

        // Derive the reverse map from the forward map so the two can never disagree.
        tensorElementTypeInfoMap = new Dictionary<TensorElementType, TensorElementTypeInfo>();
        foreach (var entry in typeInfoMap)
        {
            tensorElementTypeInfoMap.Add(entry.Value.ElementType, new TensorElementTypeInfo(entry.Key, entry.Value.TypeSize));
        }
    }

    private readonly Type _primitiveType;

    /// <summary>
    /// Constructs TensorBase
    /// </summary>
    /// <param name="primitiveType">primitive type the deriving class is using</param>
    protected TensorBase(Type primitiveType)
    {
        // Float16/BFloat16 must stay 16-bit blittable values, since arrays of
        // these types are passed to native code without copying.
        unsafe
        {
            Debug.Assert(sizeof(ushort) == sizeof(Float16));
            Debug.Assert(sizeof(ushort) == sizeof(BFloat16));
        }
        _primitiveType = primitiveType;
    }

    /// <summary>
    /// Query TensorTypeInfo by one of the supported types
    /// </summary>
    /// <param name="type">managed element type</param>
    /// <returns>TensorTypeInfo or null if not supported</returns>
    public static TensorTypeInfo GetTypeInfo(Type type)
    {
        typeInfoMap.TryGetValue(type, out var result);
        return result;
    }

    /// <summary>
    /// Query TensorElementTypeInfo by enum
    /// </summary>
    /// <param name="elementType">type enum</param>
    /// <returns>instance of TensorElementTypeInfo or null if not found</returns>
    public static TensorElementTypeInfo GetElementTypeInfo(TensorElementType elementType)
    {
        tensorElementTypeInfoMap.TryGetValue(elementType, out var result);
        return result;
    }

    /// <summary>
    /// Query TensorTypeInfo using this Tensor type
    /// </summary>
    /// <returns>traits for this tensor's element type, or null</returns>
    public TensorTypeInfo GetTypeInfo()
    {
        return GetTypeInfo(_primitiveType);
    }
}
/// <summary>
/// Various methods for creating and manipulating Tensor&lt;T&gt;
/// </summary>
public static partial class Tensor
{
    /// <summary>
    /// Creates an identity tensor of the specified size. An identity tensor is a two dimensional tensor with 1s in the diagonal.
    /// </summary>
    /// <typeparam name="T">type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
    /// <param name="size">Width and height of the identity tensor to create.</param>
    /// <returns>a <paramref name="size"/> by <paramref name="size"/> with 1s along the diagonal and zeros elsewhere.</returns>
    public static Tensor<T> CreateIdentity<T>(int size)
        => CreateIdentity(size, false, Tensor<T>.One);

    /// <summary>
    /// Creates an identity tensor of the specified size and layout (row vs column major). An identity tensor is a two dimensional tensor with 1s in the diagonal.
    /// </summary>
    /// <typeparam name="T">type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
    /// <param name="size">Width and height of the identity tensor to create.</param>
    /// <param name="columMajor">>False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major.</param>
    /// <returns>a <paramref name="size"/> by <paramref name="size"/> with 1s along the diagonal and zeros elsewhere.</returns>
    public static Tensor<T> CreateIdentity<T>(int size, bool columMajor)
        => CreateIdentity(size, columMajor, Tensor<T>.One);

    /// <summary>
    /// Creates an identity tensor of the specified size and layout (row vs column major) using the specified one value. An identity tensor is a two dimensional tensor with 1s in the diagonal. This may be used in case T is a type that doesn't have a known 1 value.
    /// </summary>
    /// <typeparam name="T">type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
    /// <param name="size">Width and height of the identity tensor to create.</param>
    /// <param name="columMajor">>False to indicate that the first dimension is most minor (closest) and the last dimension is most major (farthest): row-major. True to indicate that the last dimension is most minor (closest together) and the first dimension is most major (farthest apart): column-major.</param>
    /// <param name="oneValue">Value of <typeparamref name="T"/> that is used along the diagonal.</param>
    /// <returns>a <paramref name="size"/> by <paramref name="size"/> with 1s along the diagonal and zeros elsewhere.</returns>
    public static Tensor<T> CreateIdentity<T>(int size, bool columMajor, T oneValue)
    {
        // stackalloc into a Span needs no unsafe context
        Span<int> shape = stackalloc int[2];
        shape[0] = shape[1] = size;

        var identity = new DenseTensor<T>(shape, columMajor);
        // set ones along the diagonal
        for (int i = 0; i < size; i++)
        {
            identity.SetValue(i * size + i, oneValue);
        }
        return identity;
    }

    /// <summary>
    /// Creates a n+1-rank tensor using the specified n-rank diagonal. Values not on the diagonal will be filled with zeros.
    /// </summary>
    /// <typeparam name="T">type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
    /// <param name="diagonal">Tensor representing the diagonal to build the new tensor from.</param>
    /// <returns>A new tensor of the same layout and order as <paramref name="diagonal"/> of one higher rank, with the values of <paramref name="diagonal"/> along the diagonal and zeros elsewhere.</returns>
    public static Tensor<T> CreateFromDiagonal<T>(Tensor<T> diagonal)
        => CreateFromDiagonal(diagonal, 0);

    /// <summary>
    /// Creates a n+1-dimension tensor using the specified n-dimension diagonal at the specified offset
    /// from the center. Values not on the diagonal will be filled with zeros.
    /// </summary>
    /// <typeparam name="T">
    /// type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
    /// <param name="diagonal">Tensor representing the diagonal to build the new tensor from.</param>
    /// <param name="offset">Offset of diagonal to set in returned tensor. 0 for the main diagonal,
    /// less than zero for diagonals below, greater than zero from diagonals above.</param>
    /// <returns>A new tensor of the same layout and order as <paramref name="diagonal"/> of one higher rank,
    /// with the values of <paramref name="diagonal"/> along the specified diagonal and zeros elsewhere.</returns>
    public static Tensor<T> CreateFromDiagonal<T>(Tensor<T> diagonal, int offset)
    {
        if (diagonal.Rank < 1)
        {
            throw new ArgumentException($"Tensor {nameof(diagonal)} must have at least one dimension.", nameof(diagonal));
        }

        int diagLen = diagonal.dimensions[0];

        // TODO: allow specification of axis1 and axis2?
        var rank = diagonal.dimensions.Length + 1;
        Span<int> dimensions = rank < ArrayUtilities.StackallocMax ? stackalloc int[rank] : new int[rank];

        // assume square for the first two axes; the remaining axes mirror the
        // diagonal's trailing dimensions
        var sideLen = diagLen + Math.Abs(offset);
        dimensions[0] = dimensions[1] = sideLen;
        for (int i = 1; i < diagonal.dimensions.Length; i++)
        {
            dimensions[i + 1] = diagonal.dimensions[i];
        }

        var result = diagonal.CloneEmpty(dimensions);

        var elementsPerDiag = diagonal.Length / diagLen;
        var srcProjStride = diagonal.IsReversedStride && diagonal.Rank > 1 ? diagonal.strides[1] : 1;
        var dstProjStride = result.IsReversedStride && result.Rank > 2 ? result.strides[2] : 1;
        for (int d = 0; d < diagLen; d++)
        {
            // negative offset shifts the diagonal down (rows), positive shifts it right (columns)
            var row = offset < 0 ? d - offset : d;
            var col = offset > 0 ? d + offset : d;

            var dstBase = row * result.strides[0] + col * result.strides[1];
            var srcBase = d * diagonal.strides[0];

            for (int p = 0; p < elementsPerDiag; p++)
            {
                result.SetValue(dstBase + p * dstProjStride,
                    diagonal.GetValue(srcBase + p * srcProjStride));
            }
        }
        return result;
    }
}
/// <summary>
/// Represents a multi-dimensional collection of objects of type T that can be accessed by indices.
/// </summary>
/// <typeparam name="T">type contained within the Tensor. Typically a value type such as int, double, float, etc.</typeparam>
[DebuggerDisplay("{GetArrayString(false)}")]
// When we cross-compile for frameworks that expose ICloneable this must implement ICloneable as well.
public abstract class Tensor<T> : TensorBase, IList, IList<T>, IReadOnlyList<T>, IStructuralComparable, IStructuralEquatable
{
/// <summary>
/// Zero value for each supported element type T, produced via a boxed cast.
/// Throws NotSupportedException for any other T.
/// </summary>
internal static T Zero
{
    get
    {
        if (typeof(T) == typeof(bool))
        {
            return (T)(object)(false);
        }
        else if (typeof(T) == typeof(byte))
        {
            return (T)(object)(byte)(0);
        }
        else if (typeof(T) == typeof(char))
        {
            return (T)(object)(char)(0);
        }
        else if (typeof(T) == typeof(decimal))
        {
            return (T)(object)(decimal)(0);
        }
        else if (typeof(T) == typeof(double))
        {
            return (T)(object)(double)(0);
        }
        else if (typeof(T) == typeof(float))
        {
            return (T)(object)(float)(0);
        }
        else if (typeof(T) == typeof(int))
        {
            return (T)(object)(int)(0);
        }
        else if (typeof(T) == typeof(long))
        {
            return (T)(object)(long)(0);
        }
        else if (typeof(T) == typeof(sbyte))
        {
            return (T)(object)(sbyte)(0);
        }
        else if (typeof(T) == typeof(short))
        {
            return (T)(object)(short)(0);
        }
        else if (typeof(T) == typeof(uint))
        {
            return (T)(object)(uint)(0);
        }
        else if (typeof(T) == typeof(ulong))
        {
            return (T)(object)(ulong)(0);
        }
        else if (typeof(T) == typeof(ushort))
        {
            return (T)(object)(ushort)(0);
        }
        else if (typeof(T) == typeof(Float16))
        {
            // BUGFIX: must box a Float16, not a ushort. Unboxing (T)(object)x
            // requires the boxed instance to be exactly T, so the previous
            // (T)(object)(ushort)(0) threw InvalidCastException for T == Float16.
            return (T)(object)new Float16(0);
        }
        else if (typeof(T) == typeof(BFloat16))
        {
            // Same exact-type unboxing rule as Float16 above.
            return (T)(object)new BFloat16(0);
        }

        throw new NotSupportedException();
    }
}
/// <summary>
/// One value for each supported element type T, produced via a boxed cast.
/// Throws NotSupportedException for any other T.
/// </summary>
internal static T One
{
    get
    {
        if (typeof(T) == typeof(bool))
        {
            return (T)(object)(true);
        }
        else if (typeof(T) == typeof(byte))
        {
            return (T)(object)(byte)(1);
        }
        else if (typeof(T) == typeof(char))
        {
            return (T)(object)(char)(1);
        }
        else if (typeof(T) == typeof(decimal))
        {
            return (T)(object)(decimal)(1);
        }
        else if (typeof(T) == typeof(double))
        {
            return (T)(object)(double)(1);
        }
        else if (typeof(T) == typeof(float))
        {
            return (T)(object)(float)(1);
        }
        else if (typeof(T) == typeof(int))
        {
            return (T)(object)(int)(1);
        }
        else if (typeof(T) == typeof(long))
        {
            return (T)(object)(long)(1);
        }
        else if (typeof(T) == typeof(sbyte))
        {
            return (T)(object)(sbyte)(1);
        }
        else if (typeof(T) == typeof(short))
        {
            return (T)(object)(short)(1);
        }
        else if (typeof(T) == typeof(uint))
        {
            return (T)(object)(uint)(1);
        }
        else if (typeof(T) == typeof(ulong))
        {
            return (T)(object)(ulong)(1);
        }
        else if (typeof(T) == typeof(ushort))
        {
            return (T)(object)(ushort)(1);
        }
        else if (typeof(T) == typeof(Float16))
        {
            // BUGFIX: must box a Float16, not a ushort, or unboxing to T throws
            // InvalidCastException (unbox requires the exact boxed type).
            // 15360 == 0x3C00, the half-precision bit pattern for 1.0.
            return (T)(object)new Float16(15360);
        }
        else if (typeof(T) == typeof(BFloat16))
        {
            // Same exact-type unboxing rule as Float16 above.
            // 16256 == 0x3F80, the bfloat16 bit pattern for 1.0.
            return (T)(object)new BFloat16(16256);
        }

        throw new NotSupportedException();
    }
}
// Shape of the tensor: the size of each dimension.
internal readonly int[] dimensions;
// Per-dimension element offsets used to linearize n-d indices (computed by ArrayUtilities.GetStrides).
internal readonly int[] strides;
// True when this instance was constructed with reverseStride (column-major-like layout).
private readonly bool isReversedStride;
// Total number of elements: the product of all dimensions.
private readonly long length;
/// <summary>
/// Initialize a 1-dimensional tensor of the specified length
/// </summary>
/// <param name="length">Size of the 1-dimensional tensor</param>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="length"/> is negative.</exception>
protected Tensor(int length) : base(typeof(T))
{
    // Match the validation performed by the span-based constructor:
    // dimensions (here, the single length) must be non-negative. Previously a
    // negative length was silently stored and only failed later, in a less
    // obvious place, when derived classes allocated backing storage.
    if (length < 0)
    {
        throw new ArgumentOutOfRangeException(nameof(length), "Dimensions must be non-negative");
    }
    dimensions = new[] { length };
    strides = new[] { 1 };
    isReversedStride = false;
    this.length = length;
}
/// <summary>
/// Initialize an n-dimensional tensor with the specified dimensions and layout. ReverseStride=true gives a stride of 1-element width to the first dimension (0). ReverseStride=false gives a stride of 1-element width to the last dimension (n-1).
/// </summary>
/// <param name="dimensions">An span of integers that represent the size of each dimension of the Tensor to create.</param>
/// <param name="reverseStride">False (default) to indicate that the first dimension is most major (farthest apart) and the last dimension is most minor (closest together): akin to row-major in a rank-2 tensor. True to indicate that the last dimension is most major (farthest apart) and the first dimension is most minor (closest together): akin to column-major in a rank-2 tensor.</param>
protected Tensor(ReadOnlySpan<int> dimensions, bool reverseStride) : base(typeof(T))
{
    // NOTE(review): `span == null` compares against the *default* span (null
    // implicitly converts), so this only rejects a default/unset span, not an
    // empty span over a real array — presumably intentional; confirm with callers.
    if (dimensions == null)
    {
        throw new ArgumentNullException(nameof(dimensions));
    }
    this.dimensions = new int[dimensions.Length];
    long size = 1;
    for (int i = 0; i < dimensions.Length; i++)
    {
        if (dimensions[i] < 0)
        {
            throw new ArgumentOutOfRangeException(nameof(dimensions), "Dimensions must be non-negative");
        }
        this.dimensions[i] = dimensions[i];
        // Total element count is the product of all extents.
        size *= dimensions[i];
    }
    strides = ArrayUtilities.GetStrides(dimensions, reverseStride);
    isReversedStride = reverseStride;
    length = size;
}
/// <summary>
/// Initializes tensor with same dimensions as array, content of array is ignored.
/// ReverseStride=true gives a stride of 1-element width to the first dimension (0).
/// ReverseStride=false gives a stride of 1-element width to the last dimension (n-1).
/// </summary>
/// <param name="fromArray">Array from which to derive dimensions.</param>
/// <param name="reverseStride">
/// False (default) to indicate that the first dimension is most major (farthest apart) and the
/// last dimension is most minor (closest together): akin to row-major in a rank-2 tensor.
/// True to indicate that the last dimension is most major (farthest apart) and the first dimension
/// is most minor (closest together): akin to column-major in a rank-2 tensor.</param>
protected Tensor(Array fromArray, bool reverseStride) : base(typeof(T))
{
    if (fromArray == null)
    {
        throw new ArgumentNullException(nameof(fromArray));
    }
    // Mirror the source array's shape; its contents are deliberately ignored.
    dimensions = new int[fromArray.Rank];
    long totalSize = 1;
    for (int dim = 0; dim < dimensions.Length; dim++)
    {
        var extent = fromArray.GetLength(dim);
        dimensions[dim] = extent;
        totalSize *= extent;
    }
    strides = ArrayUtilities.GetStrides(dimensions, reverseStride);
    isReversedStride = reverseStride;
    length = totalSize;
}
/// <summary>
/// Total length of the Tensor.
/// </summary>
/// <remarks>Equal to the product of all entries in <see cref="Dimensions"/>.</remarks>
public long Length => length;
/// <summary>
/// Rank of the tensor: number of dimensions.
/// </summary>
public int Rank => dimensions.Length;
/// <summary>
/// True if strides are reversed (AKA Column-major)
/// </summary>
public bool IsReversedStride => isReversedStride;
/// <summary>
/// Returns a readonly view of the dimensions of this tensor.
/// </summary>
public ReadOnlySpan<int> Dimensions => dimensions;
/// <summary>
/// Returns a readonly view of the strides of this tensor.
/// </summary>
public ReadOnlySpan<int> Strides => strides;
/// <summary>
/// Sets all elements in Tensor to <paramref name="value"/>.
/// </summary>
/// <param name="value">Value to fill</param>
public virtual void Fill(T value)
{
    // Walk the storage in linear order, overwriting every element.
    int position = 0;
    while (position < Length)
    {
        SetValue(position, value);
        position++;
    }
}
/// <summary>
/// Creates a shallow copy of this tensor, with new backing storage.
/// </summary>
/// <returns>A shallow copy of this tensor.</returns>
public abstract Tensor<T> Clone();
/// <summary>
/// Creates a new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value.
/// </summary>
/// <returns>A new Tensor with the same layout and dimensions as this tensor with elements initialized to their default value.</returns>
public virtual Tensor<T> CloneEmpty()
{
    // Delegates to the generic overload with this tensor's own shape.
    return CloneEmpty<T>(dimensions);
}
/// <summary>
/// Creates a new Tensor with the specified dimensions and the same layout as this tensor with elements initialized to their default value.
/// </summary>
/// <param name="dimensions">An span of integers that represent the size of each dimension of the DenseTensor to create.</param>
/// <returns>A new Tensor with the same layout as this tensor and specified <paramref name="dimensions"/> with elements initialized to their default value.</returns>
public virtual Tensor<T> CloneEmpty(ReadOnlySpan<int> dimensions)
{
    return CloneEmpty<T>(dimensions);
}
/// <summary>
/// Creates a new Tensor of a different type with the same layout and size as this tensor with elements initialized to their default value.
/// </summary>
/// <typeparam name="TResult">Type contained within the new Tensor. Typically a value type such as int, double, float, etc.</typeparam>
/// <returns>A new Tensor with the same layout and dimensions as this tensor with elements of <typeparamref name="TResult"/> type initialized to their default value.</returns>
public virtual Tensor<TResult> CloneEmpty<TResult>()
{
    return CloneEmpty<TResult>(dimensions);
}
/// <summary>
/// Creates a new Tensor of a different type with the specified dimensions and the same layout as this tensor with elements initialized to their default value.
/// </summary>
/// <typeparam name="TResult">Type contained within the new Tensor. Typically a value type such as int, double, float, etc.</typeparam>
/// <param name="dimensions">An span of integers that represent the size of each dimension of the DenseTensor to create.</param>
/// <returns>A new Tensor with the same layout as this tensor of specified <paramref name="dimensions"/> with elements of <typeparamref name="TResult"/> type initialized to their default value.</returns>
public abstract Tensor<TResult> CloneEmpty<TResult>(ReadOnlySpan<int> dimensions);
/// <summary>
/// Gets the n-1 dimension diagonal from the n dimension tensor.
/// </summary>
/// <returns>An n-1 dimension tensor with the values from the main diagonal of this tensor.</returns>
public Tensor<T> GetDiagonal()
{
    return GetDiagonal(0);
}
/// <summary>
/// Gets the n-1 dimension diagonal from the n dimension tensor at the specified offset from center.
/// </summary>
/// <param name="offset">Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above.</param>
/// <returns>An n-1 dimension tensor with the values from the specified diagonal of this tensor.</returns>
/// <exception cref="InvalidOperationException">Thrown when this tensor has Rank less than 2.</exception>
/// <exception cref="ArgumentException">Thrown when <paramref name="offset"/> places the diagonal entirely outside the tensor.</exception>
public Tensor<T> GetDiagonal(int offset)
{
    // Get diagonal of first two dimensions for all remaining dimensions
    // diagnonal is as follows:
    // { 1, 2, 4 }
    // { 8, 3, 9 }
    // { 0, 7, 5 }
    // The diagonal at offset 0 is { 1, 3, 5 }
    // The diagonal at offset 1 is { 2, 9 }
    // The diagonal at offset -1 is { 8, 7 }
    if (Rank < 2)
    {
        throw new InvalidOperationException($"Cannot compute diagonal of {nameof(Tensor<T>)} with Rank less than 2.");
    }
    // TODO: allow specification of axis1 and axis2?
    var axisLength0 = dimensions[0];
    var axisLength1 = dimensions[1];
    // the diagonal will be the length of the smaller axis
    // if offset it positive, the length will shift along the second axis
    // if the offsett is negative, the length will shift along the first axis
    // In that way the length of the diagonal will be
    // Min(offset < 0 ? axisLength0 + offset : axisLength0, offset > 0 ? axisLength1 - offset : axisLength1)
    // To illustrate, consider the following
    // { 1, 2, 4, 3, 7 }
    // { 8, 3, 9, 2, 6 }
    // { 0, 7, 5, 2, 9 }
    // The diagonal at offset 0 is { 1, 3, 5 }, Min(3, 5) = 3
    // The diagonal at offset 1 is { 2, 9, 2 }, Min(3, 5 - 1) = 3
    // The diagonal at offset 3 is { 3, 6 }, Min(3, 5 - 3) = 2
    // The diagonal at offset -1 is { 8, 7 }, Min(3 - 1, 5) = 2
    var offsetAxisLength0 = offset < 0 ? axisLength0 + offset : axisLength0;
    var offsetAxisLength1 = offset > 0 ? axisLength1 - offset : axisLength1;
    var diagonalLength = Math.Min(offsetAxisLength0, offsetAxisLength1);
    if (diagonalLength <= 0)
    {
        throw new ArgumentException($"Cannot compute diagonal with offset {offset}", nameof(offset));
    }
    // Result shape: diagonal length followed by the source's trailing dimensions
    // (dimensions 2..Rank-1); the first two source dimensions collapse into one.
    var newTensorRank = Rank - 1;
    var newTensorDimensions = newTensorRank < ArrayUtilities.StackallocMax ? stackalloc int[newTensorRank] : new int[newTensorRank];
    newTensorDimensions[0] = diagonalLength;
    for (int i = 2; i < dimensions.Length; i++)
    {
        newTensorDimensions[i - 1] = dimensions[i];
    }
    var diagonalTensor = CloneEmpty(newTensorDimensions);
    // Number of elements copied per diagonal entry (product of trailing dims).
    var sizePerDiagonal = diagonalTensor.Length / diagonalTensor.Dimensions[0];
    // Linear step between consecutive trailing elements; depends on layout.
    var diagProjectionStride = diagonalTensor.IsReversedStride && diagonalTensor.Rank > 1 ? diagonalTensor.strides[1] : 1;
    var sourceProjectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1;
    for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++)
    {
        // A negative offset shifts along axis 0, a positive one along axis 1.
        var sourceIndex0 = offset < 0 ? diagIndex - offset : diagIndex;
        var sourceIndex1 = offset > 0 ? diagIndex + offset : diagIndex;
        var sourceBase = sourceIndex0 * strides[0] + sourceIndex1 * strides[1];
        var diagBase = diagIndex * diagonalTensor.strides[0];
        for (int diagProjectionIndex = 0; diagProjectionIndex < sizePerDiagonal; diagProjectionIndex++)
        {
            diagonalTensor.SetValue(diagBase + diagProjectionIndex * diagProjectionStride,
                GetValue(sourceBase + diagProjectionIndex * sourceProjectionStride));
        }
    }
    return diagonalTensor;
}
/// <summary>
/// Gets a tensor representing the elements below and including the diagonal, with the rest of the elements zero-ed.
/// </summary>
/// <returns>A tensor with the values from this tensor at and below the main diagonal and zeros elsewhere.</returns>
public Tensor<T> GetTriangle() => GetTriangle(0, upper: false);
/// <summary>
/// Gets a tensor representing the elements below and including the specified diagonal, with the rest of the elements zero-ed.
/// </summary>
/// <param name="offset">Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above.</param>
/// <returns>A tensor with the values from this tensor at and below the specified diagonal and zeros elsewhere.</returns>
public Tensor<T> GetTriangle(int offset) => GetTriangle(offset, upper: false);
/// <summary>
/// Gets a tensor representing the elements above and including the diagonal, with the rest of the elements zero-ed.
/// </summary>
/// <returns>A tensor with the values from this tensor at and above the main diagonal and zeros elsewhere.</returns>
public Tensor<T> GetUpperTriangle() => GetTriangle(0, upper: true);
/// <summary>
/// Gets a tensor representing the elements above and including the specified diagonal, with the rest of the elements zero-ed.
/// </summary>
/// <param name="offset">Offset of diagonal to set in returned tensor. 0 for the main diagonal, less than zero for diagonals below, greater than zero from diagonals above.</param>
/// <returns>A tensor with the values from this tensor at and above the specified diagonal and zeros elsewhere.</returns>
public Tensor<T> GetUpperTriangle(int offset) => GetTriangle(offset, upper: true);
/// <summary>
/// Implementation method for GetTriangle, GetLowerTriangle, GetUpperTriangle
/// </summary>
/// <param name="offset">Offset of diagonal to set in returned tensor.</param>
/// <param name="upper">true for upper triangular and false otherwise</param>
/// <returns>A tensor with the requested triangle preserved and all other elements at their default value.</returns>
/// <exception cref="InvalidOperationException">Thrown when this tensor has Rank less than 2.</exception>
public Tensor<T> GetTriangle(int offset, bool upper)
{
    if (Rank < 2)
    {
        throw new InvalidOperationException($"Cannot compute triangle of {nameof(Tensor<T>)} with Rank less than 2.");
    }
    // Similar to get diagonal except it gets every element below and including the diagonal.
    // TODO: allow specification of axis1 and axis2?
    var axisLength0 = dimensions[0];
    var axisLength1 = dimensions[1];
    // Iterate over the longer axis so every diagonal start point is visited.
    var diagonalLength = Math.Max(axisLength0, axisLength1);
    // result starts out zero-filled; only the kept triangle is copied in below.
    var result = CloneEmpty();
    var projectionSize = Length / (axisLength0 * axisLength1);
    var projectionStride = IsReversedStride && Rank > 2 ? strides[2] : 1;
    for (int diagIndex = 0; diagIndex < diagonalLength; diagIndex++)
    {
        // starting point for the tri
        var triIndex0 = offset > 0 ? diagIndex - offset : diagIndex;
        var triIndex1 = offset > 0 ? diagIndex : diagIndex + offset;
        // for lower triangle, iterate index0 keeping same index1
        // for upper triangle, iterate index1 keeping same index0
        if (triIndex0 < 0)
        {
            if (upper)
            {
                // out of bounds, ignore this diagIndex.
                continue;
            }
            else
            {
                // set index to 0 so that we can iterate on the remaining index0 values.
                triIndex0 = 0;
            }
        }
        if (triIndex1 < 0)
        {
            if (upper)
            {
                // set index to 0 so that we can iterate on the remaining index1 values.
                triIndex1 = 0;
            }
            else
            {
                // out of bounds, ignore this diagIndex.
                continue;
            }
        }
        while ((triIndex1 < axisLength1) && (triIndex0 < axisLength0))
        {
            // NOTE(review): this mixes this.strides[0] with result.strides[1];
            // safe only because CloneEmpty() preserves shape and layout, so the
            // two stride arrays hold identical values — confirm if CloneEmpty
            // is ever overridden to change layout.
            var baseIndex = triIndex0 * strides[0] + triIndex1 * result.strides[1];
            for (int projectionIndex = 0; projectionIndex < projectionSize; projectionIndex++)
            {
                // The same linear index is valid in both tensors (same layout).
                var index = baseIndex + projectionIndex * projectionStride;
                result.SetValue(index, GetValue(index));
            }
            if (upper)
            {
                triIndex1++;
            }
            else
            {
                triIndex0++;
            }
        }
    }
    return result;
}
/// <summary>
/// Reshapes the current tensor to new dimensions, using the same backing storage if possible.
/// </summary>
/// <param name="dimensions">An span of integers that represent the size of each dimension of the Tensor to create.</param>
/// <returns>A new tensor that reinterprets this tensor with different dimensions.</returns>
public abstract Tensor<T> Reshape(ReadOnlySpan<int> dimensions);
/// <summary>
/// Obtains the value at the specified indices
/// </summary>
/// <param name="indices">A one-dimensional array of integers that represent the indices specifying the position of the element to get.</param>
/// <returns>The value at the specified position in this Tensor.</returns>
public virtual T this[params int[] indices]
{
    get
    {
        if (indices == null)
        {
            throw new ArgumentNullException(nameof(indices));
        }
        // Delegate to the span-based indexer.
        var span = new ReadOnlySpan<int>(indices);
        return this[span];
    }
    set
    {
        if (indices == null)
        {
            throw new ArgumentNullException(nameof(indices));
        }
        var span = new ReadOnlySpan<int>(indices);
        this[span] = value;
    }
}
/// <summary>
/// Obtains the value at the specified indices
/// </summary>
/// <param name="indices">A span integers that represent the indices specifying the position of the element to get.</param>
/// <returns>The value at the specified position in this Tensor.</returns>
public virtual T this[ReadOnlySpan<int> indices]
{
    get
    {
        // Linearize the n-d indices into a storage offset via the strides.
        return GetValue(ArrayUtilities.GetIndex(strides, indices));
    }
    set
    {
        SetValue(ArrayUtilities.GetIndex(strides, indices), value);
    }
}
/// <summary>
/// Gets the value at the specied index, where index is a linearized version of n-dimension indices using strides.
/// </summary>
/// <param name="index">An integer index computed as a dot-product of indices.</param>
/// <returns>The value at the specified position in this Tensor.</returns>
public abstract T GetValue(int index);
/// <summary>
/// Sets the value at the specied index, where index is a linearized version of n-dimension indices using strides.
/// </summary>
/// <param name="index">An integer index computed as a dot-product of indices.</param>
/// <param name="value">The new value to set at the specified position in this Tensor.</param>
public abstract void SetValue(int index, T value);
#region statics
/// <summary>
/// Performs a value comparison of the content and shape of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices. If not equal a tensor is greater or less than another tensor based on the first non-equal element when enumerating in linear order.
/// </summary>
/// <param name="left">First tensor to compare.</param>
/// <param name="right">Second tensor to compare.</param>
/// <returns>Negative, zero, or positive per standard comparer semantics.</returns>
public static int Compare(Tensor<T> left, Tensor<T> right) =>
    StructuralComparisons.StructuralComparer.Compare(left, right);
/// <summary>
/// Performs a value equality comparison of the content of two tensors. Two tensors are equal if they have the same shape and same value at every set of indices.
/// </summary>
/// <param name="left">First tensor to compare.</param>
/// <param name="right">Second tensor to compare.</param>
/// <returns>true when both tensors are structurally equal.</returns>
public static bool Equals(Tensor<T> left, Tensor<T> right) =>
    StructuralComparisons.StructuralEqualityComparer.Equals(left, right);
#endregion
#region IEnumerable members
// Non-generic enumeration forwards to the generic enumerator.
IEnumerator IEnumerable.GetEnumerator() => ((IEnumerable<T>)this).GetEnumerator();
#endregion
#region ICollection members
int ICollection.Count => (int)Length;
bool ICollection.IsSynchronized => false;
object ICollection.SyncRoot => this; // backingArray.this?
// Copies every element, in linear order, into the destination array starting
// at the given index. Strongly-typed arrays take the fast path.
void ICollection.CopyTo(Array array, int index)
{
    if (array is T[] typedDestination)
    {
        CopyTo(typedDestination, index);
        return;
    }
    if (array == null)
    {
        throw new ArgumentNullException(nameof(array));
    }
    if (array.Rank != 1)
    {
        throw new ArgumentException("Only single dimensional arrays are supported for the requested action.", nameof(array));
    }
    if (array.Length < index + Length)
    {
        throw new ArgumentException("The number of elements in the Tensor is greater than the available space from index to the end of the destination array.", nameof(array));
    }
    for (int i = 0; i < length; i++)
    {
        array.SetValue(GetValue(i), index + i);
    }
}
#endregion
#region IList members
// Weakly-typed indexer: values are boxed/unboxed through object.
object IList.this[int index]
{
    get
    {
        return GetValue(index);
    }
    set
    {
        try
        {
            SetValue(index, (T)value);
        }
        catch (InvalidCastException)
        {
            // Translate the cast failure into the ArgumentException that
            // IList implementations conventionally throw for wrong-typed values.
            throw new ArgumentException($"The value \"{value}\" is not of type \"{typeof(T)}\" and cannot be used in this generic collection.");
        }
    }
}
/// <summary>
/// Always fixed size Tensor
/// </summary>
/// <value>always true</value>
public bool IsFixedSize => true;
/// <summary>
/// Tensor is not readonly
/// </summary>
/// <value>always false</value>
public bool IsReadOnly => false;
// Tensors are fixed-size: structural mutation is rejected.
// NOTE(review): NotSupportedException would be the conventional type here,
// but InvalidOperationException is kept since callers may already catch it.
int IList.Add(object value)
{
    throw new InvalidOperationException();
}
// Clear cannot remove elements; it resets them all to default(T).
void IList.Clear()
{
    Fill(default(T));
}
bool IList.Contains(object value)
{
    if (IsCompatibleObject(value))
    {
        return Contains((T)value);
    }
    return false;
}
int IList.IndexOf(object value)
{
    if (IsCompatibleObject(value))
    {
        return IndexOf((T)value);
    }
    return -1;
}
void IList.Insert(int index, object value)
{
    throw new InvalidOperationException();
}
void IList.Remove(object value)
{
    throw new InvalidOperationException();
}
void IList.RemoveAt(int index)
{
    throw new InvalidOperationException();
}
#endregion
#region IEnumerable<T> members
// Yields elements in linear storage order.
IEnumerator<T> IEnumerable<T>.GetEnumerator()
{
    int position = 0;
    while (position < Length)
    {
        yield return GetValue(position);
        position++;
    }
}
#endregion
#region ICollection<T> members
int ICollection<T>.Count => (int)Length;
// Fixed-size collection: adding is not possible.
void ICollection<T>.Add(T item)
{
    throw new InvalidOperationException();
}
// Clear resets every element to default(T) instead of removing elements.
void ICollection<T>.Clear()
{
    Fill(default(T));
}
bool ICollection<T>.Contains(T item)
{
    return Contains(item);
}
/// <summary>
/// Determines whether an element is in the Tensor&lt;T&gt;.
/// </summary>
/// <param name="item">
/// The object to locate in the Tensor&lt;T&gt;. The value can be null for reference types.
/// </param>
/// <returns>
/// true if item is found in the Tensor&lt;T&gt;; otherwise, false.
/// </returns>
protected virtual bool Contains(T item)
{
    if (Length == 0)
    {
        return false;
    }
    return IndexOf(item) != -1;
}
void ICollection<T>.CopyTo(T[] array, int arrayIndex)
{
    CopyTo(array, arrayIndex);
}
#endregion
/// <summary>
/// Copies the elements of the Tensor&lt;T&gt; to an Array, starting at a particular Array index.
/// </summary>
/// <param name="array">
/// The one-dimensional Array that is the destination of the elements copied from Tensor&lt;T&gt;. The Array must have zero-based indexing.
/// </param>
/// <param name="arrayIndex">
/// The zero-based index in array at which copying begins.
/// </param>
protected virtual void CopyTo(T[] array, int arrayIndex)
{
    if (array == null)
    {
        throw new ArgumentNullException(nameof(array));
    }
    if (array.Length < arrayIndex + Length)
    {
        throw new ArgumentException("The number of elements in the Tensor is greater than the available space from index to the end of the destination array.", nameof(array));
    }
    // Copy in linear storage order.
    for (int offset = 0; offset < length; offset++)
    {
        array[arrayIndex + offset] = GetValue(offset);
    }
}
// Fixed-size collection: removal is not possible.
bool ICollection<T>.Remove(T item)
{
    throw new InvalidOperationException();
}
#endregion
#region IReadOnlyCollection<T> members
// Note: truncates to int; tensors longer than int.MaxValue would misreport.
int IReadOnlyCollection<T>.Count => (int)Length;
#endregion
#region IList<T> members
// Indexer over the linearized element position.
T IList<T>.this[int index]
{
    get { return GetValue(index); }
    set { SetValue(index, value); }
}
int IList<T>.IndexOf(T item)
{
    return IndexOf(item);
}
/// <summary>
/// Determines the index of a specific item in the Tensor&lt;T&gt;.
/// </summary>
/// <param name="item">The object to locate in the Tensor&lt;T&gt;.</param>
/// <returns>The index of item if found in the tensor; otherwise, -1.</returns>
protected virtual int IndexOf(T item)
{
    // Use EqualityComparer<T>.Default rather than GetValue(i).Equals(item):
    // the direct call throws NullReferenceException when a stored element is
    // null (reference-type T) and boxes the argument for value types. The
    // default comparer handles null and uses IEquatable<T> when available.
    var comparer = EqualityComparer<T>.Default;
    for (int i = 0; i < Length; i++)
    {
        if (comparer.Equals(GetValue(i), item))
        {
            return i;
        }
    }
    return -1;
}
// Fixed-size collection: structural mutation is rejected.
void IList<T>.Insert(int index, T item)
{
    throw new InvalidOperationException();
}
void IList<T>.RemoveAt(int index)
{
    throw new InvalidOperationException();
}
#endregion
#region IReadOnlyList<T> members
// Read-only indexed access over the linearized element position.
T IReadOnlyList<T>.this[int index] => GetValue(index);
#endregion
#region IStructuralComparable members
// Structural comparison entry point: dispatches on the runtime type of
// `other` (tensor or plain Array); null sorts first.
int IStructuralComparable.CompareTo(object other, IComparer comparer)
{
    if (other == null)
    {
        return 1;
    }
    if (other is Tensor<T> otherTensor)
    {
        return CompareTo(otherTensor, comparer);
    }
    if (other is Array otherArray)
    {
        return CompareTo(otherArray, comparer);
    }
    throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} to {other.GetType()}.", nameof(other));
}
// Element-wise comparison against another tensor; both must share rank and
// dimensions. Returns the comparer's verdict on the first differing element.
private int CompareTo(Tensor<T> other, IComparer comparer)
{
    if (Rank != other.Rank)
    {
        throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} with Rank {Rank} to {nameof(other)} with Rank {other.Rank}.", nameof(other));
    }
    for (int i = 0; i < dimensions.Length; i++)
    {
        if (dimensions[i] != other.dimensions[i])
        {
            throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)}s with differning dimension {i}, {dimensions[i]} != {other.dimensions[i]}.", nameof(other));
        }
    }
    int result = 0;
    if (IsReversedStride == other.IsReversedStride)
    {
        // Same layout: linear indices address corresponding elements directly.
        for (int i = 0; i < Length; i++)
        {
            result = comparer.Compare(GetValue(i), other.GetValue(i));
            if (result != 0)
            {
                break;
            }
        }
    }
    else
    {
        // Different layouts: translate each linear index into n-d indices so
        // corresponding logical elements are compared.
        var indices = Rank < ArrayUtilities.StackallocMax ? stackalloc int[Rank] : new int[Rank];
        for (int i = 0; i < Length; i++)
        {
            ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices);
            result = comparer.Compare(this[indices], other[indices]);
            if (result != 0)
            {
                break;
            }
        }
    }
    return result;
}
// Element-wise comparison against a System.Array of matching rank/dimensions.
private int CompareTo(Array other, IComparer comparer)
{
    if (Rank != other.Rank)
    {
        throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} with Rank {Rank} to {nameof(Array)} with rank {other.Rank}.", nameof(other));
    }
    for (int i = 0; i < dimensions.Length; i++)
    {
        var otherDimension = other.GetLength(i);
        if (dimensions[i] != otherDimension)
        {
            throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} to {nameof(Array)} with differning dimension {i}, {dimensions[i]} != {otherDimension}.", nameof(other));
        }
    }
    int result = 0;
    var indices = new int[Rank];
    for (int i = 0; i < Length; i++)
    {
        // Arrays take n-d indices; the tensor side uses the same linear index.
        ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices);
        result = comparer.Compare(GetValue(i), other.GetValue(indices));
        if (result != 0)
        {
            break;
        }
    }
    return result;
}
#endregion
#region IStructuralEquatable members
// Structural equality entry point: dispatches on the runtime type of
// `other` (tensor or plain Array); null is never equal.
bool IStructuralEquatable.Equals(object other, IEqualityComparer comparer)
{
    if (other == null)
    {
        return false;
    }
    if (other is Tensor<T> otherTensor)
    {
        return Equals(otherTensor, comparer);
    }
    if (other is Array otherArray)
    {
        return Equals(otherArray, comparer);
    }
    throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} to {other.GetType()}.", nameof(other));
}
// Element-wise equality against another tensor; both must share rank and
// dimensions. False as soon as any pair of elements differs.
private bool Equals(Tensor<T> other, IEqualityComparer comparer)
{
    if (Rank != other.Rank)
    {
        throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} with Rank {Rank} to {nameof(other)} with Rank {other.Rank}.", nameof(other));
    }
    for (int i = 0; i < dimensions.Length; i++)
    {
        if (dimensions[i] != other.dimensions[i])
        {
            throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)}s with differning dimension {i}, {dimensions[i]} != {other.dimensions[i]}.", nameof(other));
        }
    }
    if (IsReversedStride == other.IsReversedStride)
    {
        // Same layout: linear indices address corresponding elements directly.
        for (int i = 0; i < Length; i++)
        {
            if (!comparer.Equals(GetValue(i), other.GetValue(i)))
            {
                return false;
            }
        }
    }
    else
    {
        // Different layouts: compare via n-d indices so logical positions match.
        var indices = Rank < ArrayUtilities.StackallocMax ? stackalloc int[Rank] : new int[Rank];
        for (int i = 0; i < Length; i++)
        {
            ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices);
            if (!comparer.Equals(this[indices], other[indices]))
            {
                return false;
            }
        }
    }
    return true;
}
// Element-wise equality against a System.Array of matching rank/dimensions.
private bool Equals(Array other, IEqualityComparer comparer)
{
    if (Rank != other.Rank)
    {
        throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} with Rank {Rank} to {nameof(Array)} with rank {other.Rank}.", nameof(other));
    }
    for (int i = 0; i < dimensions.Length; i++)
    {
        var otherDimension = other.GetLength(i);
        if (dimensions[i] != otherDimension)
        {
            throw new ArgumentException($"Cannot compare {nameof(Tensor<T>)} to {nameof(Array)} with differning dimension {i}, {dimensions[i]} != {otherDimension}.", nameof(other));
        }
    }
    var indices = new int[Rank];
    for (int i = 0; i < Length; i++)
    {
        ArrayUtilities.GetIndices(strides, IsReversedStride, i, indices);
        if (!comparer.Equals(GetValue(i), other.GetValue(indices)))
        {
            return false;
        }
    }
    return true;
}
// XOR-folds the hash of every element in linear order. Shape is deliberately
// excluded, so same-content/different-shape tensors collide (acceptable).
int IStructuralEquatable.GetHashCode(IEqualityComparer comparer)
{
    int hash = 0;
    for (int position = 0; position < Length; position++)
    {
        hash ^= comparer.GetHashCode(GetValue(position));
    }
    return hash;
}
#endregion
#region Translations
/// <summary>
/// Creates a copy of this tensor as a DenseTensor&lt;T&gt;. If this tensor is already a DenseTensor&lt;T&gt; calling this method is equivalent to calling Clone().
/// </summary>
/// <returns>A dense copy with the same dimensions and layout.</returns>
public virtual DenseTensor<T> ToDenseTensor()
{
    // Allocate with identical shape/layout, then copy element-by-element in
    // linear order (valid because both tensors share the same strides).
    var dense = new DenseTensor<T>(Dimensions, IsReversedStride);
    for (int position = 0; position < Length; position++)
    {
        dense.SetValue(position, GetValue(position));
    }
    return dense;
}
#endregion
/// <summary>
/// Get a string representation of Tensor
/// </summary>
/// <param name="includeWhitespace">true (default) to emit newlines and indentation; false for a compact single-line form.</param>
/// <returns>A nested-brace rendering of the tensor's contents, e.g. {{1,2},{3,4}}.</returns>
public string GetArrayString(bool includeWhitespace = true)
{
    var builder = new StringBuilder();
    // Always render in row-major order regardless of this tensor's layout.
    var strides = ArrayUtilities.GetStrides(dimensions);
    var indices = new int[Rank];
    var innerDimension = Rank - 1;
    var innerLength = dimensions[innerDimension];
    var outerLength = Length / innerLength;
    int indent = 0;
    // Process one innermost row per iteration.
    for (int outerIndex = 0; outerIndex < Length; outerIndex += innerLength)
    {
        ArrayUtilities.GetIndices(strides, false, outerIndex, indices);
        // Open a brace for every enclosing dimension whose index just wrapped to 0.
        while ((indent < innerDimension) && (indices[indent] == 0))
        {
            // start up
            if (includeWhitespace)
            {
                Indent(builder, indent);
            }
            indent++;
            builder.Append('{');
            if (includeWhitespace)
            {
                builder.AppendLine();
            }
        }
        // Emit the innermost row: {a,b,c}
        for (int innerIndex = 0; innerIndex < innerLength; innerIndex++)
        {
            indices[innerDimension] = innerIndex;
            if ((innerIndex == 0))
            {
                if (includeWhitespace)
                {
                    Indent(builder, indent);
                }
                builder.Append('{');
            }
            else
            {
                builder.Append(',');
            }
            builder.Append(this[indices]);
        }
        builder.Append('}');
        // Close braces for every dimension that just reached its last index;
        // otherwise emit a separating comma and stop unwinding.
        for (int i = Rank - 2; i >= 0; i--)
        {
            var lastIndex = dimensions[i] - 1;
            if (indices[i] == lastIndex)
            {
                // close out
                --indent;
                if (includeWhitespace)
                {
                    builder.AppendLine();
                    Indent(builder, indent);
                }
                builder.Append('}');
            }
            else
            {
                builder.Append(',');
                if (includeWhitespace)
                {
                    builder.AppendLine();
                }
                break;
            }
        }
    }
    return builder.ToString();
}
/// <summary>
/// Appends <paramref name="tabs"/> levels of indentation, each
/// <paramref name="spacesPerTab"/> spaces wide, to <paramref name="builder"/>.
/// </summary>
/// <param name="builder">Destination builder.</param>
/// <param name="tabs">Number of indentation levels; zero or negative appends nothing.</param>
/// <param name="spacesPerTab">Spaces per level; zero or negative appends nothing.</param>
private static void Indent(StringBuilder builder, int tabs, int spacesPerTab = 4)
{
    // StringBuilder.Append(char, int) emits the repeated character in one call
    // instead of the original nested per-character loops. The guard preserves
    // the old behavior of appending nothing for non-positive arguments.
    if (tabs > 0 && spacesPerTab > 0)
    {
        builder.Append(' ', tabs * spacesPerTab);
    }
}
// A non-null value must actually be a T. Null is acceptable only when T can
// hold null — a reference type or Nullable<T> — which is exactly when
// default(T) == null.
private static bool IsCompatibleObject(object value)
{
    if (value is T)
    {
        return true;
    }
    return value == null && default(T) == null;
}
}
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
/// <summary>
/// Holds the Checkpoint State as generated/consumed by on-device training APIs
/// </summary>
public class CheckpointState : SafeHandle
{
    // Raw native OrtCheckpointState pointer, for use by other interop wrappers.
    internal IntPtr Handle
    {
        get
        {
            return handle;
        }
    }
    /// <summary>
    /// Creates CheckpointState by loading state from path.
    /// <param name="checkpointPath"> absolute path to checkpoint file.</param>
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when the native build does not include training support.</exception>
    public CheckpointState(string checkpointPath)
        : base(IntPtr.Zero, true)
    {
        if (NativeTrainingMethods.TrainingEnabled())
        {
            var envHandle = OrtEnv.Handle; // just so it is initialized
            LoadCheckpoint(checkpointPath);
        }
        else
        {
            throw new InvalidOperationException("Training is disabled in the current build");
        }
    }
    /// <summary>
    /// Overrides SafeHandle.IsInvalid
    /// </summary>
    /// <value>returns true if handle is equal to Zero</value>
    public override bool IsInvalid { get { return handle == IntPtr.Zero; } }
    /// <summary>
    /// Loads Checkpoint state from path
    /// </summary>
    /// <param name="checkpointPath"> absolute path to checkpoint</param>
    private void LoadCheckpoint(string checkpointPath)
    {
        // The path is marshalled in the platform's native string encoding
        // before being handed to the native loader; failure surfaces as OrtException.
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtLoadCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath), out handle));
    }
    #region SafeHandle
    /// <summary>
    /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
    /// the native instance of CheckpointState
    /// </summary>
    /// <returns>always returns true</returns>
    protected override bool ReleaseHandle()
    {
        NativeTrainingMethods.OrtReleaseCheckpointState(handle);
        handle = IntPtr.Zero;
        return true;
    }
    #endregion
}
#endif
}
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
// NOTE: The order of the APIs in this struct should match exactly that in
// OrtTrainingApi (onnxruntime_training_c_api.cc)
/// <summary>
/// Managed mirror of the native OrtTrainingApi function-pointer table.
/// Each field is a raw function pointer resolved via
/// Marshal.GetDelegateForFunctionPointer. Do NOT add, remove, or reorder
/// fields independently of the native struct — the sequential layout is the
/// only thing binding each pointer to its function.
/// </summary>
[StructLayout(LayoutKind.Sequential)]
public struct OrtTrainingApi
{
    public IntPtr LoadCheckpoint;
    public IntPtr SaveCheckpoint;
    public IntPtr CreateTrainingSession;
    public IntPtr TrainingSessionGetTrainingModelOutputCount;
    public IntPtr TrainingSessionGetEvalModelOutputCount;
    public IntPtr TrainingSessionGetTrainingModelOutputName;
    public IntPtr TrainingSessionGetEvalModelOutputName;
    public IntPtr ResetGrad;
    public IntPtr TrainStep;
    public IntPtr EvalStep;
    public IntPtr SetLearningRate;
    public IntPtr GetLearningRate;
    public IntPtr OptimizerStep;
    public IntPtr RegisterLinearLRScheduler;
    public IntPtr SchedulerStep;
    public IntPtr GetParametersSize;
    public IntPtr CopyParametersToBuffer;
    public IntPtr CopyBufferToParameters;
    public IntPtr ReleaseTrainingSession;
    public IntPtr ReleaseCheckpointState;
    public IntPtr ExportModelForInferencing;
}
/// <summary>
/// P/Invoke surface for the ORT training C API. The static constructor
/// resolves the native OrtTrainingApi function-pointer table once and binds
/// a managed delegate for each entry point. When the native library was
/// built without training support, the table pointer is IntPtr.Zero and no
/// delegates are bound; callers must check <see cref="TrainingEnabled"/> first.
/// </summary>
internal static class NativeTrainingMethods
{
    // Cached copy of the core OrtApi struct (used only to reach GetTrainingApi).
    static OrtApi api_;
    // Managed copy of the native training function-pointer table.
    static OrtTrainingApi trainingApi_;
    // Raw pointer to the native table; IntPtr.Zero when training is unavailable.
    static IntPtr trainingApiPtr;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate ref OrtApi DOrtGetApi(UInt32 version);

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtTrainingApi* */ DOrtGetTrainingApi(UInt32 version);

    public static DOrtGetTrainingApi OrtGetTrainingApi;

    static NativeTrainingMethods()
    {
        DOrtGetApi OrtGetApi = (DOrtGetApi)Marshal.GetDelegateForFunctionPointer(NativeMethods.OrtGetApiBase().GetApi, typeof(DOrtGetApi));

        // TODO: Make this save the pointer, and not copy the whole structure across
        api_ = (OrtApi)OrtGetApi(4 /*ORT_API_VERSION*/);
        OrtGetTrainingApi = (DOrtGetTrainingApi)Marshal.GetDelegateForFunctionPointer(api_.GetTrainingApi, typeof(DOrtGetTrainingApi));
        trainingApiPtr = OrtGetTrainingApi(4 /*ORT_API_VERSION*/);
        if (trainingApiPtr != IntPtr.Zero)
        {
            // Bind a delegate for every native entry point; the order of
            // fields in OrtTrainingApi mirrors the native struct layout.
            trainingApi_ = (OrtTrainingApi)Marshal.PtrToStructure(trainingApiPtr, typeof(OrtTrainingApi));
            OrtLoadCheckpoint = (DOrtLoadCheckpoint)Marshal.GetDelegateForFunctionPointer(trainingApi_.LoadCheckpoint, typeof(DOrtLoadCheckpoint));
            OrtSaveCheckpoint = (DOrtSaveCheckpoint)Marshal.GetDelegateForFunctionPointer(trainingApi_.SaveCheckpoint, typeof(DOrtSaveCheckpoint));
            OrtCreateTrainingSession = (DOrtCreateTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.CreateTrainingSession, typeof(DOrtCreateTrainingSession));
            OrtGetTrainingModelOutputCount = (DOrtGetTrainingModelOutputCount)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetTrainingModelOutputCount, typeof(DOrtGetTrainingModelOutputCount));
            OrtGetEvalModelOutputCount = (DOrtGetEvalModelOutputCount)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelOutputCount, typeof(DOrtGetEvalModelOutputCount));
            OrtGetTrainingModelOutputName = (DOrtGetTrainingModelOutputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetTrainingModelOutputName, typeof(DOrtGetTrainingModelOutputName));
            OrtGetEvalModelOutputName = (DOrtGetEvalModelOutputName)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainingSessionGetEvalModelOutputName, typeof(DOrtGetEvalModelOutputName));
            OrtResetGrad = (DOrtResetGrad)Marshal.GetDelegateForFunctionPointer(trainingApi_.ResetGrad, typeof(DOrtResetGrad));
            OrtTrainStep = (DOrtTrainStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.TrainStep, typeof(DOrtTrainStep));
            OrtEvalStep = (DOrtEvalStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.EvalStep, typeof(DOrtEvalStep));
            OrtSetLearningRate = (DOrtSetLearningRate)Marshal.GetDelegateForFunctionPointer(trainingApi_.SetLearningRate, typeof(DOrtSetLearningRate));
            OrtGetLearningRate = (DOrtGetLearningRate)Marshal.GetDelegateForFunctionPointer(trainingApi_.GetLearningRate, typeof(DOrtGetLearningRate));
            OrtOptimizerStep = (DOrtOptimizerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.OptimizerStep, typeof(DOrtOptimizerStep));
            OrtRegisterLinearLRScheduler = (DOrtRegisterLinearLRScheduler)Marshal.GetDelegateForFunctionPointer(trainingApi_.RegisterLinearLRScheduler, typeof(DOrtRegisterLinearLRScheduler));
            OrtSchedulerStep = (DOrtSchedulerStep)Marshal.GetDelegateForFunctionPointer(trainingApi_.SchedulerStep, typeof(DOrtSchedulerStep));
            OrtReleaseTrainingSession = (DOrtReleaseTrainingSession)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseTrainingSession, typeof(DOrtReleaseTrainingSession));
            OrtReleaseCheckpointState = (DOrtReleaseCheckpointState)Marshal.GetDelegateForFunctionPointer(trainingApi_.ReleaseCheckpointState, typeof(DOrtReleaseCheckpointState));
        }
    }

    #region TrainingSession API

    /// <summary>
    /// Loads a checkpoint from the given path.
    /// </summary>
    /// <param name="checkpointPath">checkpoint string path (platform-serialized bytes)</param>
    /// <param name="checkpointState">(Output) Loaded OrtCheckpointState instance</param>
    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */DOrtLoadCheckpoint(
        byte[] checkpointPath,
        out IntPtr /* (OrtCheckpointState**) */ checkpointState);

    public static DOrtLoadCheckpoint OrtLoadCheckpoint;

    /// <summary>
    /// Saves the session state to a checkpoint at the given path.
    /// </summary>
    /// <param name="checkpointPath">checkpoint string path (platform-serialized bytes)</param>
    /// <param name="session">Native OrtTrainingSession instance to snapshot</param>
    /// <param name="saveOptimizerState">Whether optimizer state is included</param>
    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */DOrtSaveCheckpoint(
        byte[] checkpointPath,
        IntPtr /*(OrtTrainingSession*)*/ session,
        bool saveOptimizerState);

    public static DOrtSaveCheckpoint OrtSaveCheckpoint;

    /// <summary>
    /// Creates an instance of OrtTrainingSession with provided parameters.
    /// </summary>
    /// <param name="environment">Native OrtEnv instance</param>
    /// <param name="sessionOptions">Native SessionOptions instance</param>
    /// <param name="checkpointState">Loaded OrtCheckpointState instance</param>
    /// <param name="trainModelPath">model string path</param>
    /// <param name="evalModelPath">model string path</param>
    /// <param name="optimizerModelPath">model string path</param>
    /// <param name="session">(Output) Created native OrtTrainingSession instance</param>
    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /* OrtStatus* */DOrtCreateTrainingSession(
        IntPtr /* (OrtEnv*) */ environment,
        IntPtr /* (OrtSessionOptions*) */ sessionOptions,
        IntPtr /* (OrtCheckpointState*) */ checkpointState,
        byte[] trainModelPath,
        byte[] evalModelPath,
        byte[] optimizerModelPath,
        out IntPtr /* (OrtTrainingSession**) */ session);

    public static DOrtCreateTrainingSession OrtCreateTrainingSession;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTrainingModelOutputCount(
        IntPtr /*(OrtSession*)*/ session,
        out UIntPtr count);

    public static DOrtGetTrainingModelOutputCount OrtGetTrainingModelOutputCount;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(OrtStatus*)*/ DOrtGetEvalModelOutputCount(
        IntPtr /*(OrtSession*)*/ session,
        out UIntPtr count);

    public static DOrtGetEvalModelOutputCount OrtGetEvalModelOutputCount;

    // Output name accessors: the returned char* is allocated by the supplied
    // allocator and must be freed by the caller.
    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(OrtStatus*)*/ DOrtGetTrainingModelOutputName(
        IntPtr /*(OrtSession*)*/ session,
        UIntPtr index,
        IntPtr /*(OrtAllocator*)*/ allocator,
        out IntPtr /*(char**)*/name);

    public static DOrtGetTrainingModelOutputName OrtGetTrainingModelOutputName;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(OrtStatus*)*/ DOrtGetEvalModelOutputName(
        IntPtr /*(OrtSession*)*/ session,
        UIntPtr index,
        IntPtr /*(OrtAllocator*)*/ allocator,
        out IntPtr /*(char**)*/name);

    public static DOrtGetEvalModelOutputName OrtGetEvalModelOutputName;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(OrtStatus*)*/ DOrtResetGrad(
        IntPtr /*(OrtSession*)*/ session);

    public static DOrtResetGrad OrtResetGrad;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtTrainStep(
        IntPtr /*(OrtTrainingSession*)*/ session,
        IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
        UIntPtr inputCount,
        IntPtr[] /* (OrtValue*[])*/ inputValues,
        UIntPtr outputCount,
        IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */
        );

    public static DOrtTrainStep OrtTrainStep;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtEvalStep(
        IntPtr /*(OrtTrainingSession*)*/ session,
        IntPtr /*(OrtSessionRunOptions*)*/ runOptions, // can be null to use the default options
        UIntPtr inputCount,
        IntPtr[] /* (OrtValue*[])*/ inputValues,
        UIntPtr outputCount,
        IntPtr[] outputValues /* An array of output value pointers. Array must be allocated by the caller */
        );

    public static DOrtEvalStep OrtEvalStep;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtOptimizerStep(
        IntPtr /*(OrtTrainingSession*)*/ session,
        IntPtr /*(OrtSessionRunOptions*)*/ runOptions // can be null to use the default options
        );

    public static DOrtOptimizerStep OrtOptimizerStep;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtSetLearningRate(
        IntPtr /*(OrtTrainingSession*)*/ session,
        float learningRate
        );

    public static DOrtSetLearningRate OrtSetLearningRate;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtGetLearningRate(
        IntPtr /*(OrtTrainingSession*)*/ session,
        out float learningRate
        );

    public static DOrtGetLearningRate OrtGetLearningRate;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtRegisterLinearLRScheduler(
        IntPtr /*(OrtTrainingSession*)*/ session,
        long warmupStepCount,
        long totalStepCount,
        float learningRate
        );

    public static DOrtRegisterLinearLRScheduler OrtRegisterLinearLRScheduler;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate IntPtr /*(ONNStatus*)*/ DOrtSchedulerStep(
        IntPtr /*(OrtTrainingSession*)*/ session
        );

    public static DOrtSchedulerStep OrtSchedulerStep;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate void DOrtReleaseTrainingSession(IntPtr /*(OrtSession*)*/session);

    public static DOrtReleaseTrainingSession OrtReleaseTrainingSession;

    [UnmanagedFunctionPointer(CallingConvention.Winapi)]
    public delegate void DOrtReleaseCheckpointState(IntPtr /*(OrtSession*)*/session);

    public static DOrtReleaseCheckpointState OrtReleaseCheckpointState;

    #endregion TrainingSession API

    /// <summary>
    /// Returns true when the loaded native library exposes the training API,
    /// i.e. the static constructor obtained a non-null OrtTrainingApi table.
    /// </summary>
    public static bool TrainingEnabled()
    {
        return trainingApiPtr != IntPtr.Zero;
    }
} //class NativeTrainingMethods
#endif
} //namespace
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
/// <summary>
/// Identifies which learning-rate policy is currently active on a
/// TrainingSession.
/// </summary>
enum LRScheduler
{
    None = 0,     // no LR policy chosen yet
    Constant = 1, // fixed LR set via SetLearningRate
    Linear = 2    // linear scheduler registered via RegisterLinearLRScheduler
}
/// <summary>
/// Represents a Training Session on an ONNX Model.
/// This is a IDisposable class and it must be disposed of
/// using either a explicit call to Dispose() method or
/// a pattern of using() block. If this is a member of another
/// class that class must also become IDisposable and it must
/// dispose of TrainingSession in its Dispose() method.
/// </summary>
public class TrainingSession : IDisposable
{
    /// <summary>
    /// A pointer to a underlying native instance of OrtTrainingSession
    /// </summary>
    private IntPtr _nativeHandle;

    // Output counts and names for the train/eval graphs, captured once in Init().
    private ulong _trainOutputCount;
    private ulong _evalOutputCount;
    private List<string> _trainOutputNames;
    private List<string> _evalOutputNames;

    // Owned only when the caller did not supply their own SessionOptions.
    private SessionOptions _builtInSessionOptions = null;
    // Default RunOptions so step methods don't allocate per call.
    private RunOptions _builtInRunOptions = null;
    private LRScheduler _scheduler = LRScheduler.None;
    private bool _disposed = false;

    #region Public API

    /// <summary>
    /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
    /// </summary>
    /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
    /// <param name="trainModelPath">Specify path to training model graph.</param>
    /// <param name="evalModelPath">Specify path to eval model graph.</param>
    /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
    public TrainingSession(CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
    {
        Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
    }

    /// <summary>
    /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
    /// </summary>
    /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
    /// <param name="trainModelPath">Specify path to training model graph.</param>
    /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
    public TrainingSession(CheckpointState state, string trainModelPath, string optimizerModelPath)
    {
        Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), null, NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
    }

    /// <summary>
    /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
    /// </summary>
    /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
    /// <param name="trainModelPath">Specify path to training model graph.</param>
    public TrainingSession(CheckpointState state, string trainModelPath)
    {
        Init(null, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), null, null);
    }

    /// <summary>
    /// Creates TrainingSession from the model and checkpoint in <paramref name="state"/>.
    /// </summary>
    /// <param name="options">Session options</param>
    /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
    /// <param name="trainModelPath">Specify path to training model graph.</param>
    /// <param name="evalModelPath">Specify path to eval model graph.</param>
    /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
    public TrainingSession(SessionOptions options, CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
    {
        Init(options, state, NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath), NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
    }

    /// <summary>
    /// Runs a train step on the loaded model for the given inputs.
    /// </summary>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
    public void TrainStep(
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
        IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
    {
        TrainStep(_builtInRunOptions, inputValues, outputValues);
    }

    /// <summary>
    /// Runs a train step on the loaded model for the given inputs. Uses the given RunOptions for this run.
    /// </summary>
    /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
    /// <exception cref="ArgumentException">Thrown when the number of outputs does not match the train model.</exception>
    public void TrainStep(
        RunOptions options,
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
        IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
    {
        // Use the Count property rather than the LINQ Count() extension.
        if (_trainOutputCount != (ulong)outputValues.Count)
        {
            throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
        }
        IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
        IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
            inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
    }

    /// <summary>
    /// Runs the loaded model for the given inputs, and fetches the graph outputs.
    /// </summary>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
    public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
    {
        // Delegate to the RunOptions overload, mirroring the fixed-output
        // overloads above (avoids duplicating the step/wrap logic).
        return TrainStep(_builtInRunOptions, inputValues);
    }

    /// <summary>
    /// Runs the loaded model for the given inputs, and fetches the graph outputs. Uses the given RunOptions for this run.
    /// </summary>
    /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
    public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
        RunOptions options,
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
    {
        using (var ortValues = new DisposableList<OrtValue>((int)_trainOutputCount))
        {
            IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
            // The native call fills this array with freshly created OrtValue*.
            IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount];
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
                inputValuesArray, (UIntPtr)_trainOutputCount, outputValuesArray));

            // Wrap raw handles so they are released if anything below throws.
            foreach (var v in outputValuesArray)
            {
                ortValues.Add(new OrtValue(v));
            }

            // NOTE(review): CreateFromOrtValue appears to take over the native
            // value so disposing ortValues afterwards is safe — confirm against
            // DisposableNamedOnnxValue's ownership semantics.
            var result = new DisposableList<DisposableNamedOnnxValue>(_trainOutputNames.Count);
            try
            {
                for (int i = 0; i < ortValues.Count; i++)
                {
                    var ortValue = ortValues[i];
                    result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(_trainOutputNames[i], ortValue));
                }
            }
            catch (OnnxRuntimeException)
            {
                result.Dispose();
                throw;
            }
            return result;
        }
    }

    /// <summary>
    /// Sets the reset grad flag on the training graph. The gradient buffers will be reset while executing the
    /// next train step.
    /// </summary>
    public void ResetGrad()
    {
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtResetGrad(_nativeHandle));
    }

    /// <summary>
    /// Runs an eval step on the loaded model for the given inputs. The eval graph must be passed while TrainingSession creation.
    /// </summary>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
    public void EvalStep(
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
        IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
    {
        EvalStep(_builtInRunOptions, inputValues, outputValues);
    }

    /// <summary>
    /// Runs an eval step on the loaded model for the given inputs. The eval graph must be passed while TrainingSession creation.
    /// </summary>
    /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
    /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
    /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
    /// <exception cref="ArgumentException">Thrown when the number of outputs does not match the eval model.</exception>
    public void EvalStep(
        RunOptions options,
        IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
        IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
    {
        // BUGFIX: previously `_evalOutputCount.Equals(outputValues.Count)`
        // resolved to UInt64.Equals(object) with a boxed int, which always
        // returns false, so this method always threw. Compare as ulong.
        if (_evalOutputCount != (ulong)outputValues.Count)
        {
            throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
        }
        IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
        IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */
        // BUGFIX: previously invoked OrtTrainStep; eval must call OrtEvalStep.
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(_nativeHandle, options.Handle, (UIntPtr)inputValues.Count,
            inputValuesArray, (UIntPtr)outputValues.Count, outputValuesArray));
    }

    /// <summary>
    /// Sets a constant learning rate for the session. LR must be controlled by either this method
    /// or by registering a LR scheduler.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when a non-constant LR scheduler is already registered.</exception>
    public void SetLearningRate(float learningRate)
    {
        if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Constant)
        {
            throw new InvalidOperationException("Cannot set constant LR while using LR scheduler.");
        }
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetLearningRate(_nativeHandle, learningRate));
        _scheduler = LRScheduler.Constant;
    }

    /// <summary>
    /// Gets the current learning rate for the session.
    /// </summary>
    public float GetLearningRate()
    {
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetLearningRate(_nativeHandle, out float lr));
        return lr;
    }

    /// <summary>
    /// Registers a linear learning rate scheduler for the session. LR must be controlled by either
    /// the SetLearningRate method or by registering a LR scheduler.
    /// <param name="warmupStepCount"> Number of warmup steps</param>
    /// <param name="totalStepCount"> Number of total steps</param>
    /// <param name="initialLearningRate"> Initial learning rate</param>
    /// </summary>
    public void RegisterLinearLRScheduler(long warmupStepCount,
                                          long totalStepCount,
                                          float initialLearningRate)
    {
        if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Constant)
        {
            throw new InvalidOperationException("Cannot set LR scheduler while using constant LR.");
        }
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtRegisterLinearLRScheduler(_nativeHandle, warmupStepCount, totalStepCount, initialLearningRate));
        _scheduler = LRScheduler.Linear;
    }

    /// <summary>
    /// Runs a LR scheduler step. There must be a valid LR scheduler registered for the training session.
    /// </summary>
    /// <exception cref="InvalidOperationException">Thrown when no non-constant scheduler is registered.</exception>
    public void SchedulerStep()
    {
        if (_scheduler == LRScheduler.Constant || _scheduler == LRScheduler.None)
        {
            throw new InvalidOperationException("Cannot take scheduler step without registering a valid LR scheduler.");
        }
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSchedulerStep(_nativeHandle));
    }

    /// <summary>
    /// Runs an optimizer step on the loaded model for the given inputs. The optimizer graph must be passed while TrainingSession creation.
    /// </summary>
    public void OptimizerStep()
    {
        OptimizerStep(_builtInRunOptions);
    }

    /// <summary>
    /// Runs an optimizer step on the loaded model. The optimizer graph must be passed while TrainingSession creation.
    /// </summary>
    /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
    public void OptimizerStep(RunOptions options)
    {
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtOptimizerStep(_nativeHandle, options.Handle));
    }

    /// <summary>
    /// Saves a checkpoint to path. It can be loaded into <see cref="CheckpointState"/>
    /// </summary>
    /// <param name="path">Specify path for saving the checkpoint.</param>
    /// <param name="saveOptimizerState">Flag indicating whether to save optimizer state or not.</param>
    public void SaveCheckpoint(string path, bool saveOptimizerState = false)
    {
        NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(NativeOnnxValueHelper.GetPlatformSerializedString(path), _nativeHandle, saveOptimizerState));
    }

    #endregion

    #region private methods

    // Creates the native training session and caches output counts/names.
    // On any failure the partially created native resources are released.
    private void Init(SessionOptions sessOptions, CheckpointState state, byte[] trainModelPath, byte[] evalModelPath, byte[] optimizerModelPath)
    {
        if (!NativeTrainingMethods.TrainingEnabled())
        {
            throw new InvalidOperationException("Training is disabled in the current build.");
        }
        var options = sessOptions;
        if (sessOptions == null)
        {
            // No caller-supplied options: create (and own) a default set.
            _builtInSessionOptions = new SessionOptions();
            options = _builtInSessionOptions;
        }
        var envHandle = OrtEnv.Handle;
        try
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCreateTrainingSession(envHandle, options.Handle, state.Handle, trainModelPath,
                evalModelPath, optimizerModelPath, out _nativeHandle));

            UIntPtr outputCount = UIntPtr.Zero;
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputCount(_nativeHandle, out outputCount));
            _trainOutputCount = outputCount.ToUInt64();

            // get all the output names and metadata
            _trainOutputNames = new List<string>();
            for (ulong i = 0; i < _trainOutputCount; i++)
            {
                _trainOutputNames.Add(GetOutputName(i, true));
            }

            if (evalModelPath != null)
            {
                outputCount = UIntPtr.Zero;
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputCount(_nativeHandle, out outputCount));
                _evalOutputCount = outputCount.ToUInt64();
                _evalOutputNames = new List<string>();
                for (ulong i = 0; i < _evalOutputCount; i++)
                {
                    _evalOutputNames.Add(GetOutputName(i, false));
                }
            }

            _builtInRunOptions = new RunOptions(); // create a default built-in run option, and avoid creating a new one every run() call
        }
        catch (Exception)
        {
            CleanupHelper(true);
            throw;
        }
    }

    // Fetches the i-th output name of the train (training==true) or eval
    // graph. The native string is allocated by the default allocator and is
    // freed via the OrtMemoryAllocation wrapper.
    private string GetOutputName(ulong index, bool training)
    {
        var allocator = OrtAllocator.DefaultInstance;
        IntPtr nameHandle;
        string str = null;
        if (training)
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputName(
                _nativeHandle,
                (UIntPtr)index,
                allocator.Pointer,
                out nameHandle));
        }
        else
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputName(
                _nativeHandle,
                (UIntPtr)index,
                allocator.Pointer,
                out nameHandle));
        }
        using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
        {
            str = NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
        }
        return str;
    }

    // Collects the native OrtValue* handles of the given fixed buffers.
    // String tensors cannot be pre-allocated outputs, so they are rejected
    // when input==false.
    private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
    {
        var valuesArray = new IntPtr[values.Count];
        // Iterate directly instead of ElementAt(index): ElementAt is O(n)
        // per call (O(n^2) total) for collections without an indexer.
        int index = 0;
        foreach (var v in values)
        {
            if (!input && v.ElementType == Tensors.TensorElementType.String)
            {
                throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
            }
            valuesArray[index++] = v.Value.Handle;
        }
        return valuesArray;
    }

    /// <summary>
    /// Other classes access
    /// </summary>
    internal IntPtr Handle
    {
        get
        {
            return _nativeHandle;
        }
    }

    #endregion

    #region IDisposable

    /// <summary>
    /// Finalizer.
    /// </summary>
    ~TrainingSession()
    {
        Dispose(false);
    }

    /// <summary>
    /// IDisposable implementation
    /// </summary>
    public void Dispose()
    {
        Dispose(true);
        GC.SuppressFinalize(this);
    }

    /// <summary>
    /// IDisposable implementation
    /// </summary>
    /// <param name="disposing">true if invoked from Dispose() method</param>
    protected virtual void Dispose(bool disposing)
    {
        if (_disposed)
        {
            return;
        }
        CleanupHelper(disposing);
        _disposed = true;
    }

    // Releases owned managed wrappers (only when disposing) and always
    // releases the native training session handle.
    private void CleanupHelper(bool disposing)
    {
        if (disposing)
        {
            if (_builtInRunOptions != null)
            {
                _builtInRunOptions.Dispose();
                _builtInRunOptions = null;
            }

            if (_builtInSessionOptions != null)
            {
                _builtInSessionOptions.Dispose();
                _builtInSessionOptions = null;
            }
        }

        // cleanup unmanaged resources
        if (_nativeHandle != IntPtr.Zero)
        {
            NativeTrainingMethods.OrtReleaseTrainingSession(_nativeHandle);
            _nativeHandle = IntPtr.Zero;
        }
    }

    #endregion
}
#endif
}
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition=" '$(AndroidApplication)'=='true' ">
<AndroidLibrary Include="$(MSBuildThisFileDirectory)..\..\runtimes\android\native\*">
<Link>%(Filename)%(Extension)</Link>
</AndroidLibrary>
</ItemGroup>
</Project>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition=" '$(AndroidApplication)'=='true' ">
<AndroidLibrary Bind="false" Include="$(MSBuildThisFileDirectory)..\..\runtimes\android\native\*">
<Link>%(Filename)%(Extension)</Link>
</AndroidLibrary>
</ItemGroup>
</Project>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition="('$(OutputType)'!='Library' OR '$(IsAppExtension)'=='True')">
<NativeReference Include="$(MSBuildThisFileDirectory)..\..\runtimes\ios\native\onnxruntime.xcframework">
<Kind>Static</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
<ForceLoad>True</ForceLoad>
<LinkerFlags>-lc++</LinkerFlags>
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
</ItemGroup>
</Project>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition="('$(OutputType)'!='Library' OR '$(IsAppExtension)'=='True')">
<NativeReference Condition="'$(Platform)' == 'arm64'" Include="$(MSBuildThisFileDirectory)..\..\runtimes\osx.10.14-arm64\native\libonnxruntime.dylib">
<Kind>Framework</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
<ForceLoad>True</ForceLoad>
<LinkerFlags>-lc++</LinkerFlags>
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
<NativeReference Condition="'$(Platform)' == 'x64'" Include="$(MSBuildThisFileDirectory)..\..\runtimes\osx.10.14-x64\native\libonnxruntime.dylib">
<Kind>Framework</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
<ForceLoad>True</ForceLoad>
<LinkerFlags>-lc++</LinkerFlags>
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
</ItemGroup>
</Project>
\ No newline at end of file
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemDefinitionGroup>
<ClCompile>
<AdditionalIncludeDirectories>$(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ClCompile>
<ResourceCompile>
<AdditionalIncludeDirectories>$(MSBuildThisFileDirectory)../../build/native/include/;%(AdditionalIncludeDirectories)</AdditionalIncludeDirectories>
</ResourceCompile>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(PlatformTarget)' == 'ARM64'">
<Link>
<AdditionalDependencies>$(MSBuildThisFileDirectory)../../runtimes/win-arm64/native/onnxruntime.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(PlatformTarget)' == 'ARM'">
<Link>
<AdditionalDependencies>$(MSBuildThisFileDirectory)../../runtimes/win-arm/native/onnxruntime.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')">
<Link>
<AdditionalDependencies>$(MSBuildThisFileDirectory)../../runtimes/win-x64/native/onnxruntime.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<ItemDefinitionGroup Condition="'$(PlatformTarget)' == 'x86' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' == 'true')">
<Link>
<AdditionalDependencies>$(MSBuildThisFileDirectory)../../runtimes/win-x86/native/onnxruntime.lib;%(AdditionalDependencies)</AdditionalDependencies>
</Link>
</ItemDefinitionGroup>
<PropertyGroup>
<EnginePlatform Condition="'$(Platform)' == 'Win32'">x86</EnginePlatform>
<EnginePlatform Condition="'$(Platform)' == 'ARM64'">arm64</EnginePlatform>
<EnginePlatform Condition="'$(Platform)' == 'ARM'">arm</EnginePlatform>
<EnginePlatform Condition="'$(Platform)' != 'Win32' AND '$(Platform)' != 'ARM64' AND '$(Platform)' != 'ARM'">$(Platform)</EnginePlatform>
</PropertyGroup>
<PropertyGroup>
<OnnxRuntimeBinary>$(MSBuildThisFileDirectory)..\..\runtimes\win-$(EnginePlatform)\native\onnxruntime.dll</OnnxRuntimeBinary>
</PropertyGroup>
<!-- Assume apps using the Microsoft.ML.OnnxRuntime.DirectML package only want the DirectML binaries (no need for a build dependency). -->
<PropertyGroup Label="Globals" Condition="Exists('$(MSBuildThisFileDirectory)include\dml_provider_factory.h')">
<Microsoft_AI_DirectML_SkipDebugLayerCopy>true</Microsoft_AI_DirectML_SkipDebugLayerCopy>
<Microsoft_AI_DirectML_SkipLink>true</Microsoft_AI_DirectML_SkipLink>
<Microsoft_AI_DirectML_SkipIncludeDir>true</Microsoft_AI_DirectML_SkipIncludeDir>
</PropertyGroup>
<ItemGroup>
<!-- x64 -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime.dll"
Condition="'$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_shared.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_shared.dll')">
<Link>onnxruntime_providers_shared.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_cuda.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_cuda.dll')">
<Link>onnxruntime_providers_cuda.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_dnnl.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_dnnl.dll')">
<Link>onnxruntime_providers_dnnl.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_tensorrt.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_tensorrt.dll')">
<Link>onnxruntime_providers_tensorrt.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_openvino.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\onnxruntime_providers_openvino.dll')">
<Link>onnxruntime_providers_openvino.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\dnnl.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\dnnl.dll')">
<Link>dnnl.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\mklml.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\mklml.dll')">
<Link>mklml.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\libiomp5md.dll"
Condition="('$(PlatformTarget)' == 'x64' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' != 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x64\native\libiomp5md.dll')">
<Link>libiomp5md.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<!-- arm64 -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime.dll"
Condition="'$(PlatformTarget)' == 'ARM64'">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll"
Condition="'$(PlatformTarget)' == 'ARM64' AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm64\native\onnxruntime_providers_shared.dll')">
<Link>onnxruntime_providers_shared.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<!-- arm -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime.dll"
Condition="'$(PlatformTarget)' == 'ARM'">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime_providers_shared.dll"
Condition="'$(PlatformTarget)' == 'ARM' AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-arm\native\onnxruntime_providers_shared.dll')">
<Link>onnxruntime_providers_shared.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<!-- x86 -->
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\onnxruntime.dll"
Condition="('$(PlatformTarget)' == 'x86' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' == 'true'))">
<Link>onnxruntime.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\dnnl.dll"
Condition="('$(PlatformTarget)' == 'x86' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' == 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\dnnl.dll')">
<Link>dnnl.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\mklml.dll"
Condition="('$(PlatformTarget)' == 'x86' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' == 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\mklml.dll')">
<Link>mklml.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
<None Include="$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\libiomp5md.dll"
Condition="('$(PlatformTarget)' == 'x86' OR ('$(PlatformTarget)' == 'AnyCPU' AND '$(Prefer32Bit)' == 'true')) AND
Exists('$(MSBuildThisFileDirectory)..\..\runtimes\win-x86\native\libiomp5md.dll')">
<Link>libiomp5md.dll</Link>
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Visible>false</Visible>
</None>
</ItemGroup>
</Project>
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<Target Name="Microsoft_ML_OnnxRuntime_CheckPrerequisites" BeforeTargets="BeforeBuild">
<!--
Special case .NET Core portable applications. When building a portable .NET Core app,
the PlatformTarget is empty, and you don't know until runtime (i.e. which dotnet.exe)
what processor architecture will be used.
-->
<Error Condition="('$(PlatformTarget)' != 'x64' AND '$(PlatformTarget)' != 'arm32' AND '$(PlatformTarget)' != 'arm64' AND '$(PlatformTarget)' != 'x86' AND '$(PlatformTarget)' != 'AnyCPU') AND
('$(OutputType)' == 'Exe' OR '$(OutputType)'=='WinExe') AND
!('$(TargetFrameworkIdentifier)' == '.NETCoreApp' AND '$(PlatformTarget)' == '') AND
('$(TargetFrameworkIdentifier)' != 'Xamarin.iOS' AND
$([MSBuild]::GetTargetPlatformIdentifier('$(TargetFramework)')) != 'ios') AND
'$(SuppressOnnxRuntimePlatformCompatibilityError)' != 'true'"
Text="Microsoft.ML.OnnxRuntime only supports the AnyCPU, x64, arm32, arm64 and x86 platforms at this time."/>
</Target>
</Project>
<?xml version="1.0" encoding="utf-8"?>
<Project ToolsVersion="4.0" xmlns="http://schemas.microsoft.com/developer/msbuild/2003">
<ItemGroup Condition="('$(OutputType)'!='Library' OR '$(IsAppExtension)'=='True')">
<NativeReference Condition="'$(Platform)' == 'iPhoneSimulator'" Include="$(MSBuildThisFileDirectory)..\..\runtimes\ios\native\onnxruntime.xcframework\ios-arm64_x86_64-simulator\onnxruntime.framework">
<Kind>Framework</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
<ForceLoad>True</ForceLoad>
<LinkerFlags>-lc++</LinkerFlags>
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
<NativeReference Condition="'$(Platform)' == 'iPhone'" Include="$(MSBuildThisFileDirectory)..\..\runtimes\ios\native\onnxruntime.xcframework\ios-arm64\onnxruntime.framework">
<Kind>Framework</Kind>
<IsCxx>True</IsCxx>
<SmartLink>True</SmartLink>
<ForceLoad>True</ForceLoad>
<LinkerFlags>-lc++</LinkerFlags>
<WeakFrameworks>CoreML</WeakFrameworks>
</NativeReference>
</ItemGroup>
</Project>
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment