// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using Microsoft.ML.OnnxRuntime.Tensors;
using System;
using System.Runtime.InteropServices;
using System.Text;

namespace Microsoft.ML.OnnxRuntime
{
    /// <summary>
    /// This class enables binding of inputs and/or outputs to pre-allocated
    /// memory. This enables interesting scenarios. For example, if your input
    /// already resides in some pre-allocated memory like GPU, you can bind
    /// that piece of memory to an input name and shape and onnxruntime will use that as input.
    /// Other traditional inputs that already exist as Tensors can also be bound.
    ///
    /// Note that this arrangement is designed to minimize data copies, and to that effect
    /// your memory allocations must match what is expected by the model, whether you run on
    /// CPU or GPU. A data copy will still be made if your pre-allocated memory location does not
    /// match the one expected by the model. However, copies with OrtIoBinding are only done once,
    /// at the time of the binding, not at run time. This means that if your input data required a copy,
    /// your further input modifications will not be seen by onnxruntime unless you rebind it, even if it is
    /// the same buffer. If you require the scenario where data is copied, OrtIoBinding may not be the best match
    /// for your use case.
    ///
    /// The fact that a data copy is not made during runtime also has performance implications.
    /// </summary>
    public class OrtIoBinding : SafeHandle
    {
        /// <summary>
        /// Creates the native binding object for the given session.
        /// Use InferenceSession.CreateIoBinding()
        /// </summary>
        /// <param name="session">session whose native handle the binding is created for</param>
        /// <exception cref="ArgumentNullException">session is null</exception>
        internal OrtIoBinding(InferenceSession session)
            : base(IntPtr.Zero, true)
        {
            if (session == null)
            {
                throw new ArgumentNullException("session");
            }

            NativeApiStatus.VerifySuccess(NativeMethods.OrtCreateIoBinding(session.Handle, out handle));
        }

        /// <summary>
        /// Raw native handle of this binding.
        /// </summary>
        internal IntPtr Handle
        {
            get
            {
                return handle;
            }
        }

        /// <summary>
        /// Overrides SafeHandle.IsInvalid
        /// </summary>
        /// <value>returns true if handle is equal to Zero</value>
        public override bool IsInvalid
        {
            get
            {
                return handle == IntPtr.Zero;
            }
        }

        /// <summary>
        /// Bind a piece of pre-allocated native memory as an OrtValue Tensor with a given shape
        /// to an input with a given name. The model will read the specified input from that memory
        /// possibly avoiding the need to copy between devices. OrtMemoryAllocation continues to own
        /// the chunk of native memory, and the allocation should be alive until the end of execution.
        /// </summary>
        /// <param name="name">name of the input</param>
        /// <param name="elementType">Tensor element type</param>
        /// <param name="shape">tensor shape</param>
        /// <param name="allocation">native memory allocation</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        public void BindInput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation)
        {
            BindOrtAllocation(name, elementType, shape, allocation, true);
        }

        /// <summary>
        /// Bind externally (not from OrtAllocator) allocated memory as input.
        /// The model will read the specified input from that memory
        /// possibly avoiding the need to copy between devices. The user code continues to own
        /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution.
        /// </summary>
        /// <param name="name">name of the input</param>
        /// <param name="allocation">non ort allocated memory</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        public void BindInput(string name, OrtExternalAllocation allocation)
        {
            BindExternalAllocation(name, allocation, true);
        }

        /// <summary>
        /// Bind the input with the given name as an OrtValue Tensor allocated in pinned managed memory.
        /// Instance of FixedBufferOnnxValue owns the memory and should be alive until the end of execution.
        /// </summary>
        /// <param name="name">name of input</param>
        /// <param name="fixedValue">value to bind; must be a tensor</param>
        /// <exception cref="ArgumentNullException">fixedValue is null</exception>
        /// <exception cref="OnnxRuntimeException">fixedValue is not a tensor</exception>
        public void BindInput(string name, FixedBufferOnnxValue fixedValue)
        {
            if (fixedValue == null)
            {
                throw new ArgumentNullException("fixedValue");
            }

            if (fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Binding works only with Tensors");
            }

            BindInputOrOutput(name, fixedValue.Value.Handle, true);
        }

        /// <summary>
        /// Blocks until device completes all preceding requested tasks.
        /// Useful for memory synchronization.
        /// </summary>
        public void SynchronizeBoundInputs()
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundInputs(handle));
        }

        /// <summary>
        /// Bind model output to an OrtValue as Tensor with a given type and shape. An instance of OrtMemoryAllocation
        /// owns the memory and should be alive for the time of execution.
        /// </summary>
        /// <param name="name">name of the output</param>
        /// <param name="elementType">tensor element type</param>
        /// <param name="shape">tensor shape</param>
        /// <param name="allocation">allocated memory</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        public void BindOutput(string name, Tensors.TensorElementType elementType, long[] shape, OrtMemoryAllocation allocation)
        {
            BindOrtAllocation(name, elementType, shape, allocation, false);
        }

        /// <summary>
        /// Bind externally (not from OrtAllocator) allocated memory as output.
        /// The model will write the specified output to that memory
        /// possibly avoiding the need to copy between devices. The user code continues to own
        /// the chunk of externally allocated memory, and the allocation should be alive until the end of execution.
        /// </summary>
        /// <param name="name">name of the output</param>
        /// <param name="allocation">non ort allocated memory</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        public void BindOutput(string name, OrtExternalAllocation allocation)
        {
            BindExternalAllocation(name, allocation, false);
        }

        /// <summary>
        /// Bind model output to a given instance of FixedBufferOnnxValue which owns the underlying
        /// pinned managed memory and should be alive for the time of execution.
        /// </summary>
        /// <param name="name">name of the output</param>
        /// <param name="fixedValue">value to bind; must be a tensor</param>
        /// <exception cref="ArgumentNullException">fixedValue is null</exception>
        /// <exception cref="OnnxRuntimeException">fixedValue is not a tensor</exception>
        public void BindOutput(string name, FixedBufferOnnxValue fixedValue)
        {
            if (fixedValue == null)
            {
                throw new ArgumentNullException("fixedValue");
            }

            if (fixedValue.OnnxValueType != OnnxValueType.ONNX_TYPE_TENSOR)
            {
                throw new OnnxRuntimeException(ErrorCode.InvalidArgument, "Binding works only with Tensors");
            }

            BindInputOrOutput(name, fixedValue.Value.Handle, false);
        }

        /// <summary>
        /// This function will bind model output with the given name to a device
        /// specified by the memInfo.
        /// </summary>
        /// <param name="name">output name</param>
        /// <param name="memInfo">instance of memory info</param>
        /// <exception cref="ArgumentNullException">memInfo is null</exception>
        public void BindOutputToDevice(string name, OrtMemoryInfo memInfo)
        {
            // Validate before pinning anything so a bad argument never leaves a pinned handle behind.
            if (memInfo == null)
            {
                throw new ArgumentNullException("memInfo");
            }

            var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
            using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
            {
                NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutputToDevice(handle, pinnedName.Pointer, memInfo.Pointer));
            }
        }

        /// <summary>
        /// Blocks until device completes all preceding requested tasks.
        /// Useful for memory synchronization.
        /// </summary>
        public void SynchronizeBoundOutputs()
        {
            NativeApiStatus.VerifySuccess(NativeMethods.OrtSynchronizeBoundOutputs(handle));
        }

        /// <summary>
        /// Bind allocation obtained from an Ort allocator
        /// </summary>
        /// <param name="name">name</param>
        /// <param name="elementType">data type</param>
        /// <param name="shape">tensor shape</param>
        /// <param name="allocation">ort allocation</param>
        /// <param name="isInput">whether this is input or output</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        private void BindOrtAllocation(string name, Tensors.TensorElementType elementType, long[] shape,
            OrtMemoryAllocation allocation, bool isInput)
        {
            if (allocation == null)
            {
                throw new ArgumentNullException("allocation");
            }

            // The temporary OrtValue wrapper is disposed as soon as the native bind call returns.
            using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info,
                                                                     elementType,
                                                                     shape,
                                                                     allocation.Pointer,
                                                                     allocation.Size))
            {
                BindInputOrOutput(name, ortValue.Handle, isInput);
            }
        }

        /// <summary>
        /// Bind external allocation as input or output.
        /// The allocation is owned by the user code.
        /// </summary>
        /// <param name="name">name</param>
        /// <param name="allocation">non ort allocated memory</param>
        /// <param name="isInput">whether this is an input or output</param>
        /// <exception cref="ArgumentNullException">allocation is null</exception>
        private void BindExternalAllocation(string name, OrtExternalAllocation allocation, bool isInput)
        {
            if (allocation == null)
            {
                throw new ArgumentNullException("allocation");
            }

            // The temporary OrtValue wrapper is disposed as soon as the native bind call returns.
            using (var ortValue = OrtValue.CreateTensorValueWithData(allocation.Info,
                                                                     allocation.ElementType,
                                                                     allocation.Shape,
                                                                     allocation.Pointer,
                                                                     allocation.Size))
            {
                BindInputOrOutput(name, ortValue.Handle, isInput);
            }
        }

        /// <summary>
        /// Internal helper: pins the UTF-8 encoded name and forwards the OrtValue handle
        /// to the native input or output bind call.
        /// </summary>
        /// <param name="name">name of the input or output</param>
        /// <param name="ortValue">native OrtValue handle to bind</param>
        /// <param name="isInput">true to bind as input, false to bind as output</param>
        private void BindInputOrOutput(string name, IntPtr ortValue, bool isInput)
        {
            var utf8NamePinned = GCHandle.Alloc(NativeOnnxValueHelper.StringToZeroTerminatedUtf8(name), GCHandleType.Pinned);
            using (var pinnedName = new PinnedGCHandle(utf8NamePinned))
            {
                if (isInput)
                {
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtBindInput(handle, pinnedName.Pointer, ortValue));
                }
                else
                {
                    NativeApiStatus.VerifySuccess(NativeMethods.OrtBindOutput(handle, pinnedName.Pointer, ortValue));
                }
            }
        }

        /// <summary>
        /// Returns an array of output names in the same order they were bound
        /// </summary>
        /// <returns>array of output names; empty when the native call reports zero outputs</returns>
        public string[] GetOutputNames()
        {
            IntPtr buffer = IntPtr.Zero;
            IntPtr lengths = IntPtr.Zero;
            UIntPtr count = UIntPtr.Zero;
            var allocator = OrtAllocator.DefaultInstance;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputNames(handle, allocator.Pointer, out buffer, out lengths, out count));

            if (count.Equals(UIntPtr.Zero))
            {
                return new string[0];
            }

            // Both native buffers are wrapped so that they are disposed when this block exits,
            // including when an exception is thrown while decoding.
            using (var bufferAllocation = new OrtMemoryAllocation(allocator, buffer, 0))
            using (var lengthsAllocation = new OrtMemoryAllocation(allocator, lengths, 0))
            {
                int outputCount = (int)count;

                // First pass: read one pointer-sized length per name and accumulate the total byte count.
                var nameLengths = new int[outputCount];
                int totalLength = 0;
                for (int i = 0; i < outputCount; ++i)
                {
                    var len = (int)Marshal.ReadIntPtr(lengths, IntPtr.Size * i);
                    nameLengths[i] = len;
                    totalLength += len;
                }

                // The names are laid out back to back in a single buffer; copy it to managed memory once.
                var stringData = new byte[totalLength];
                Marshal.Copy(buffer, stringData, 0, stringData.Length);

                // Second pass: slice the copied bytes by the recorded lengths and decode each as UTF-8.
                string[] result = new string[outputCount];
                int readOffset = 0;
                for (int i = 0; i < outputCount; ++i)
                {
                    var strLen = nameLengths[i];
                    result[i] = Encoding.UTF8.GetString(stringData, readOffset, strLen);
                    readOffset += strLen;
                }

                return result;
            }
        }

        /// <summary>
        /// This fetches bound outputs after running the model with RunWithBinding()
        /// </summary>
        /// <returns>disposable collection holding one OrtValue per bound output</returns>
        public IDisposableReadOnlyCollection GetOutputValues()
        {
            // NOTE(review): the collection types appear here without type arguments; if the project
            // declares them as generic, the element type added below is OrtValue - confirm.
            IntPtr ortValues = IntPtr.Zero;
            UIntPtr count = UIntPtr.Zero;
            var allocator = OrtAllocator.DefaultInstance;
            NativeApiStatus.VerifySuccess(NativeMethods.OrtGetBoundOutputValues(handle, allocator.Pointer, out ortValues, out count));

            if (count.Equals(UIntPtr.Zero))
            {
                return new DisposableList();
            }

            // The native array of OrtValue pointers is wrapped so that it is disposed when this block exits.
            using (var ortValuesAllocation = new OrtMemoryAllocation(allocator, ortValues, 0))
            {
                int outputCount = (int)count;
                var ortList = new DisposableList(outputCount);
                try
                {
                    for (int i = 0; i < outputCount; ++i)
                    {
                        IntPtr ortValue = Marshal.ReadIntPtr(ortValues, IntPtr.Size * i);
                        ortList.Add(new OrtValue(ortValue));
                    }
                }
                catch (Exception)
                {
                    // Dispose whatever was wrapped so far before propagating the failure.
                    ortList.Dispose();
                    throw;
                }

                return ortList;
            }
        }

        /// <summary>
        /// Clear all bound inputs and start anew
        /// </summary>
        public void ClearBoundInputs()
        {
            NativeMethods.OrtClearBoundInputs(handle);
        }

        /// <summary>
        /// Clear all bound outputs
        /// </summary>
        public void ClearBoundOutputs()
        {
            NativeMethods.OrtClearBoundOutputs(handle);
        }

        #region SafeHandle
        /// <summary>
        /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
        /// the native instance of OrtIoBinding
        /// </summary>
        /// <returns>always returns true</returns>
        protected override bool ReleaseHandle()
        {
            NativeMethods.OrtReleaseIoBinding(handle);
            handle = IntPtr.Zero;
            return true;
        }

        #endregion
    }
}