// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

using System;
using System.Runtime.InteropServices;

namespace Microsoft.ML.OnnxRuntime
{
#if __ENABLE_TRAINING_ON_DEVICE__
    /// <summary>
    /// Holds the Checkpoint State as generated/consumed by on-device training APIs.
    /// Wraps a native checkpoint-state handle that this instance owns and frees
    /// through <see cref="ReleaseHandle"/>.
    /// </summary>
    public class CheckpointState : SafeHandle
    {
        /// <summary>
        /// The raw native checkpoint-state handle, for use by other native interop calls.
        /// </summary>
        internal IntPtr Handle
        {
            get { return handle; }
        }

        /// <summary>
        /// Creates a CheckpointState by loading state from <paramref name="checkpointPath"/>.
        /// </summary>
        /// <param name="checkpointPath">Absolute path to the checkpoint file.</param>
        /// <exception cref="InvalidOperationException">
        /// Training is disabled in the current build.
        /// </exception>
        /// <exception cref="ArgumentNullException">
        /// <paramref name="checkpointPath"/> is null.
        /// </exception>
        /// <remarks>
        /// If the native load reports a failure, the exception raised by
        /// NativeApiStatus.VerifySuccess propagates to the caller.
        /// </remarks>
        public CheckpointState(string checkpointPath)
            : base(IntPtr.Zero, true)
        {
            // Checked first so a build without training always reports that,
            // regardless of the argument value (matches the original ordering).
            if (!NativeTrainingMethods.TrainingEnabled())
            {
                throw new InvalidOperationException("Training is disabled in the current build");
            }

            if (checkpointPath == null)
            {
                throw new ArgumentNullException(nameof(checkpointPath));
            }

            // Touch OrtEnv.Handle only so the ORT environment is initialized
            // before the native checkpoint load; the value itself is not needed.
            _ = OrtEnv.Handle;
            LoadCheckpoint(checkpointPath);
        }

        /// <summary>
        /// Overrides SafeHandle.IsInvalid.
        /// </summary>
        /// <value>true if the handle is equal to <see cref="IntPtr.Zero"/>.</value>
        public override bool IsInvalid
        {
            get { return handle == IntPtr.Zero; }
        }

        /// <summary>
        /// Loads checkpoint state from the given path into this instance's native handle.
        /// </summary>
        /// <param name="checkpointPath">Absolute path to the checkpoint file.</param>
        private void LoadCheckpoint(string checkpointPath)
        {
            // The native API receives the path in the platform's serialized form;
            // on failure VerifySuccess throws and the handle stays Zero (invalid).
            NativeApiStatus.VerifySuccess(
                NativeTrainingMethods.OrtLoadCheckpoint(
                    NativeOnnxValueHelper.GetPlatformSerializedString(checkpointPath),
                    out handle));
        }

        #region SafeHandle
        /// <summary>
        /// Overrides SafeHandle.ReleaseHandle() to properly dispose of
        /// the native instance of CheckpointState.
        /// </summary>
        /// <returns>Always returns true.</returns>
        protected override bool ReleaseHandle()
        {
            NativeTrainingMethods.OrtReleaseCheckpointState(handle);
            // Clear the field so the handle reads as invalid after release.
            handle = IntPtr.Zero;
            return true;
        }
        #endregion
    }
#endif
}