// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
    /// <summary>
    /// Records which mechanism currently controls the learning rate of a
    /// <see cref="TrainingSession"/>: nothing yet, a constant value, or a linear scheduler.
    /// </summary>
    enum LRScheduler
    {
        None = 0,
        Constant = 1,
        Linear = 2
    }

    /// <summary>
    /// Represents a Training Session on an ONNX Model.
    /// This is an IDisposable class and it must be disposed of
    /// using either an explicit call to the Dispose() method or
    /// a using() block. If this is a member of another
    /// class, that class must also become IDisposable and it must
    /// dispose of the TrainingSession in its Dispose() method.
    /// </summary>
    public class TrainingSession : IDisposable
    {
        /// <summary>
        /// A pointer to the underlying native instance of OrtTrainingSession.
        /// </summary>
        private IntPtr _nativeHandle;

        private ulong _trainOutputCount;
        private ulong _evalOutputCount;
        private List<string> _trainOutputNames;
        private List<string> _evalOutputNames;

        // Only created (and therefore only disposed by this class) when the caller
        // did not supply its own SessionOptions.
        private SessionOptions _builtInSessionOptions = null;
        private RunOptions _builtInRunOptions = null;
        private LRScheduler _scheduler = LRScheduler.None;
        private bool _disposed = false;

        #region Public API

        /// <summary>
        /// Creates a TrainingSession from the model and checkpoint in <see cref="CheckpointState"/>.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="evalModelPath">Specify path to eval model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
        {
            Init(null, state,
                 NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath),
                 NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath),
                 NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Creates a TrainingSession from the model and checkpoint in <see cref="CheckpointState"/>.
        /// No eval model is loaded.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath, string optimizerModelPath)
        {
            Init(null, state,
                 NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath),
                 null,
                 NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Creates a TrainingSession from the model and checkpoint in <see cref="CheckpointState"/>.
        /// Neither an eval model nor an optimizer model is loaded.
        /// </summary>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        public TrainingSession(CheckpointState state, string trainModelPath)
        {
            Init(null, state,
                 NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath),
                 null,
                 null);
        }

        /// <summary>
        /// Creates a TrainingSession from the model and checkpoint in <see cref="CheckpointState"/>.
        /// </summary>
        /// <param name="options">Session options. When null, a built-in default instance is created and owned by this session.</param>
        /// <param name="state">Model checkpoint loaded into <see cref="CheckpointState"/>.</param>
        /// <param name="trainModelPath">Specify path to training model graph.</param>
        /// <param name="evalModelPath">Specify path to eval model graph.</param>
        /// <param name="optimizerModelPath">Specify path to optimizer model graph.</param>
        public TrainingSession(SessionOptions options, CheckpointState state, string trainModelPath, string evalModelPath, string optimizerModelPath)
        {
            Init(options, state,
                 NativeOnnxValueHelper.GetPlatformSerializedString(trainModelPath),
                 NativeOnnxValueHelper.GetPlatformSerializedString(evalModelPath),
                 NativeOnnxValueHelper.GetPlatformSerializedString(optimizerModelPath));
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs, using the built-in run options.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the train model's output count.</exception>
        public void TrainStep(
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            TrainStep(_builtInRunOptions, inputValues, outputValues);
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs. Uses the given RunOptions for this run.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the train model's output count.</exception>
        public void TrainStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            if (_trainOutputCount != (ulong)outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of train model ({_trainOutputCount}).");
            }

            IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
            IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(
                _nativeHandle, options.Handle,
                (UIntPtr)inputValues.Count, inputValuesArray,
                (UIntPtr)outputValues.Count, outputValuesArray));
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs and fetches the graph outputs,
        /// using the built-in run options.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
        {
            // Same work as the RunOptions overload; only the options instance differs.
            return TrainStep(_builtInRunOptions, inputValues);
        }

        /// <summary>
        /// Runs a train step on the loaded model for the given inputs and fetches the graph outputs.
        /// Uses the given RunOptions for this run.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <returns>Output Tensors in a Collection of NamedOnnxValue. User must dispose the output.</returns>
        public IDisposableReadOnlyCollection<DisposableNamedOnnxValue> TrainStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues)
        {
            using (var ortValues = new DisposableList<OrtValue>((int)_trainOutputCount))
            {
                IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
                // Left zero-initialized; the native call fills in one handle per train output.
                IntPtr[] outputValuesArray = new IntPtr[(int)_trainOutputCount];

                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtTrainStep(
                    _nativeHandle, options.Handle,
                    (UIntPtr)inputValues.Count, inputValuesArray,
                    (UIntPtr)_trainOutputCount, outputValuesArray));

                // Wrap the raw handles immediately so they are released by the using block
                // if anything below fails.
                foreach (var v in outputValuesArray)
                {
                    ortValues.Add(new OrtValue(v));
                }

                var result = new DisposableList<DisposableNamedOnnxValue>(_trainOutputNames.Count);
                try
                {
                    // NOTE(review): presumably CreateFromOrtValue takes over the native value from
                    // the OrtValue wrapper, so disposing ortValues afterwards is harmless — confirm.
                    for (int i = 0; i < ortValues.Count; i++)
                    {
                        var ortValue = ortValues[i];
                        result.Add(DisposableNamedOnnxValue.CreateFromOrtValue(_trainOutputNames[i], ortValue));
                    }
                }
                catch (Exception)
                {
                    // Release whatever was already wrapped, regardless of the exception type.
                    result.Dispose();
                    throw;
                }

                return result;
            }
        }

        /// <summary>
        /// Sets the reset grad flag on the training graph. The gradient buffers will be reset while executing the
        /// next train step.
        /// </summary>
        public void ResetGrad()
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtResetGrad(_nativeHandle));
        }

        /// <summary>
        /// Runs an eval step on the loaded model for the given inputs, using the built-in run options.
        /// The eval graph must be passed while creating the TrainingSession.
        /// </summary>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the eval model's output count.</exception>
        public void EvalStep(
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            EvalStep(_builtInRunOptions, inputValues, outputValues);
        }

        /// <summary>
        /// Runs an eval step on the loaded model for the given inputs. Uses the given RunOptions for this run.
        /// The eval graph must be passed while creating the TrainingSession.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        /// <param name="inputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the input values.</param>
        /// <param name="outputValues">Specify a collection of <see cref="FixedBufferOnnxValue"/> that indicates the output values.</param>
        /// <exception cref="ArgumentException">The number of output values differs from the eval model's output count.</exception>
        public void EvalStep(
            RunOptions options,
            IReadOnlyCollection<FixedBufferOnnxValue> inputValues,
            IReadOnlyCollection<FixedBufferOnnxValue> outputValues)
        {
            // Compare numerically: ulong.Equals(object) with a boxed int is always false,
            // which would reject every call.
            if (_evalOutputCount != (ulong)outputValues.Count)
            {
                throw new ArgumentException($"Length of {nameof(outputValues)} ({outputValues.Count}) must match that of eval model ({_evalOutputCount}).");
            }

            IntPtr[] inputValuesArray = GetOrtValuesHandles(inputValues, true);
            IntPtr[] outputValuesArray = GetOrtValuesHandles(outputValues, false); /* pointers to Pre-allocated OrtValue instances */

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtEvalStep(
                _nativeHandle, options.Handle,
                (UIntPtr)inputValues.Count, inputValuesArray,
                (UIntPtr)outputValues.Count, outputValuesArray));
        }

        /// <summary>
        /// Sets a constant learning rate for the session. LR must be controlled by either this method
        /// or by registering a LR scheduler.
        /// </summary>
        /// <param name="learningRate">Learning rate to apply.</param>
        /// <exception cref="InvalidOperationException">A LR scheduler is already registered.</exception>
        public void SetLearningRate(float learningRate)
        {
            if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Constant)
            {
                throw new InvalidOperationException("Cannot set constant LR while using LR scheduler.");
            }

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSetLearningRate(_nativeHandle, learningRate));
            _scheduler = LRScheduler.Constant;
        }

        /// <summary>
        /// Gets the current learning rate for the session.
        /// </summary>
        /// <returns>The learning rate reported by the native session.</returns>
        public float GetLearningRate()
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetLearningRate(_nativeHandle, out float lr));
            return lr;
        }

        /// <summary>
        /// Registers a linear learning rate scheduler for the session. LR must be controlled by either
        /// the SetLearningRate method or by registering a LR scheduler.
        /// </summary>
        /// <param name="warmupStepCount">Number of warmup steps</param>
        /// <param name="totalStepCount">Number of total steps</param>
        /// <param name="initialLearningRate">Initial learning rate</param>
        /// <exception cref="InvalidOperationException">A constant learning rate has already been set.</exception>
        public void RegisterLinearLRScheduler(long warmupStepCount, long totalStepCount, float initialLearningRate)
        {
            // Mirror of SetLearningRate: only the "nothing chosen yet" and "already linear"
            // states may proceed; a constant LR excludes registering a scheduler.
            if (_scheduler != LRScheduler.None && _scheduler != LRScheduler.Linear)
            {
                throw new InvalidOperationException("Cannot set LR scheduler while using constant LR.");
            }

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtRegisterLinearLRScheduler(
                _nativeHandle, warmupStepCount, totalStepCount, initialLearningRate));
            _scheduler = LRScheduler.Linear;
        }

        /// <summary>
        /// Runs a LR scheduler step. There must be a valid LR scheduler registered for the training session.
        /// </summary>
        /// <exception cref="InvalidOperationException">No LR scheduler is registered.</exception>
        public void SchedulerStep()
        {
            if (_scheduler == LRScheduler.Constant || _scheduler == LRScheduler.None)
            {
                throw new InvalidOperationException("Cannot take scheduler step without registering a valid LR scheduler.");
            }

            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSchedulerStep(_nativeHandle));
        }

        /// <summary>
        /// Runs an optimizer step on the loaded model, using the built-in run options.
        /// The optimizer graph must be passed while creating the TrainingSession.
        /// </summary>
        public void OptimizerStep()
        {
            OptimizerStep(_builtInRunOptions);
        }

        /// <summary>
        /// Runs an optimizer step on the loaded model. Uses the given RunOptions for this run.
        /// The optimizer graph must be passed while creating the TrainingSession.
        /// </summary>
        /// <param name="options">Specify <see cref="RunOptions"/> for step.</param>
        public void OptimizerStep(RunOptions options)
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtOptimizerStep(_nativeHandle, options.Handle));
        }

        /// <summary>
        /// Saves a checkpoint to path. It can be loaded into <see cref="CheckpointState"/>.
        /// </summary>
        /// <param name="path">Specify path for saving the checkpoint.</param>
        /// <param name="saveOptimizerState">Flag indicating whether to save optimizer state or not.</param>
        public void SaveCheckpoint(string path, bool saveOptimizerState = false)
        {
            NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtSaveCheckpoint(
                NativeOnnxValueHelper.GetPlatformSerializedString(path), _nativeHandle, saveOptimizerState));
        }

        #endregion

        #region private methods

        /// <summary>
        /// Shared constructor body: creates the native training session and caches the
        /// output counts and names of the train model (and of the eval model when one is given).
        /// </summary>
        /// <param name="sessOptions">Caller-supplied session options, or null to create a built-in instance.</param>
        /// <param name="state">Checkpoint state whose handle is passed to the native session.</param>
        /// <param name="trainModelPath">Platform-serialized path of the training model.</param>
        /// <param name="evalModelPath">Platform-serialized path of the eval model, or null.</param>
        /// <param name="optimizerModelPath">Platform-serialized path of the optimizer model, or null.</param>
        /// <exception cref="InvalidOperationException">Training is disabled in the current build.</exception>
        private void Init(SessionOptions sessOptions, CheckpointState state, byte[] trainModelPath, byte[] evalModelPath, byte[] optimizerModelPath)
        {
            if (!NativeTrainingMethods.TrainingEnabled())
            {
                throw new InvalidOperationException("Training is disabled in the current build.");
            }

            var options = sessOptions;
            if (sessOptions == null)
            {
                _builtInSessionOptions = new SessionOptions();
                options = _builtInSessionOptions;
            }

            var envHandle = OrtEnv.Handle;
            try
            {
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtCreateTrainingSession(
                    envHandle, options.Handle, state.Handle,
                    trainModelPath, evalModelPath, optimizerModelPath, out _nativeHandle));

                UIntPtr outputCount = UIntPtr.Zero;
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputCount(_nativeHandle, out outputCount));
                _trainOutputCount = outputCount.ToUInt64();

                // get all the output names and metadata
                _trainOutputNames = new List<string>();
                for (ulong i = 0; i < _trainOutputCount; i++)
                {
                    _trainOutputNames.Add(GetOutputName(i, true));
                }

                if (evalModelPath != null)
                {
                    outputCount = UIntPtr.Zero;
                    NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputCount(_nativeHandle, out outputCount));
                    _evalOutputCount = outputCount.ToUInt64();

                    _evalOutputNames = new List<string>();
                    for (ulong i = 0; i < _evalOutputCount; i++)
                    {
                        _evalOutputNames.Add(GetOutputName(i, false));
                    }
                }

                // create a default built-in run option, and avoid creating a new one every run() call
                _builtInRunOptions = new RunOptions();
            }
            catch (Exception)
            {
                // Construction failed part-way: release whatever was already created.
                CleanupHelper(true);
                throw;
            }
        }

        /// <summary>
        /// Fetches the name of one model output from the native session.
        /// </summary>
        /// <param name="index">Zero-based output index.</param>
        /// <param name="training">true to query the training model, false to query the eval model.</param>
        /// <returns>The output name converted from native UTF-8.</returns>
        private string GetOutputName(ulong index, bool training)
        {
            var allocator = OrtAllocator.DefaultInstance;
            IntPtr nameHandle;

            if (training)
            {
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetTrainingModelOutputName(
                    _nativeHandle, (UIntPtr)index, allocator.Pointer, out nameHandle));
            }
            else
            {
                NativeApiStatus.VerifySuccess(NativeTrainingMethods.OrtGetEvalModelOutputName(
                    _nativeHandle, (UIntPtr)index, allocator.Pointer, out nameHandle));
            }

            // The allocation wrapper is disposed once the string has been copied out.
            // NOTE(review): presumably disposing it returns the native buffer to the allocator — confirm.
            using (var ortAllocation = new OrtMemoryAllocation(allocator, nameHandle, 0))
            {
                return NativeOnnxValueHelper.StringFromNativeUtf8(nameHandle);
            }
        }

        /// <summary>
        /// Collects the native OrtValue handles of the given values, in enumeration order.
        /// </summary>
        /// <param name="values">Values whose handles are needed.</param>
        /// <param name="input">true when the values are inputs; false when they are outputs.</param>
        /// <returns>An array with one handle per value.</returns>
        /// <exception cref="NotSupportedException">An output value has string element type.</exception>
        private IntPtr[] GetOrtValuesHandles(IReadOnlyCollection<FixedBufferOnnxValue> values, bool input)
        {
            var valuesArray = new IntPtr[values.Count];

            // Single pass over the collection; indexing through ElementAt would re-enumerate
            // from the start on every iteration for collections that are not lists.
            int index = 0;
            foreach (var v in values)
            {
                if (!input && v.ElementType == Tensors.TensorElementType.String)
                {
                    throw new NotSupportedException("Using string type FixedBufferOnnxValue in outputs is not supported.");
                }

                valuesArray[index++] = v.Value.Handle;
            }

            return valuesArray;
        }

        /// <summary>
        /// Native session handle, exposed for other classes in this assembly.
        /// </summary>
        internal IntPtr Handle
        {
            get
            {
                return _nativeHandle;
            }
        }

        #endregion

        #region IDisposable

        /// <summary>
        /// Finalizer.
        /// </summary>
        ~TrainingSession()
        {
            Dispose(false);
        }

        /// <summary>
        /// IDisposable implementation
        /// </summary>
        public void Dispose()
        {
            Dispose(true);
            GC.SuppressFinalize(this);
        }

        /// <summary>
        /// IDisposable implementation
        /// </summary>
        /// <param name="disposing">true if invoked from Dispose() method</param>
        protected virtual void Dispose(bool disposing)
        {
            if (_disposed)
            {
                return;
            }

            CleanupHelper(disposing);
            _disposed = true;
        }

        /// <summary>
        /// Releases the built-in managed option objects (only when <paramref name="disposing"/> is true)
        /// and the native training session. Fields are cleared after release, so repeated calls are harmless.
        /// </summary>
        /// <param name="disposing">true to also dispose managed members.</param>
        private void CleanupHelper(bool disposing)
        {
            if (disposing)
            {
                if (_builtInRunOptions != null)
                {
                    _builtInRunOptions.Dispose();
                    _builtInRunOptions = null;
                }

                if (_builtInSessionOptions != null)
                {
                    _builtInSessionOptions.Dispose();
                    _builtInSessionOptions = null;
                }
            }

            // cleanup unmanaged resources
            if (_nativeHandle != IntPtr.Zero)
            {
                NativeTrainingMethods.OrtReleaseTrainingSession(_nativeHandle);
                _nativeHandle = IntPtr.Zero;
            }
        }

        #endregion
    }
#endif
}