binding.ts 2.11 KB
Newer Older
gaoqiong's avatar
gaoqiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

// eslint-disable-next-line @typescript-eslint/no-unused-vars
import type {InferenceSession} from 'onnxruntime-common';
import {NativeModules} from 'react-native';

/**
 * model loading information
 */
interface ModelLoadInfo {
  /**
   * Key identifying the loaded session instance on the native side; passed back to run() to
   * select which session to execute.
   */
  readonly key: string;

  /**
   * Names of the loaded model's inputs.
   */
  readonly inputNames: string[];

  /**
   * Names of the loaded model's outputs.
   */
  readonly outputNames: string[];
}

/**
 * Tensor type for react native, which doesn't allow ArrayBuffer, so data will be encoded as Base64 string.
 */
interface EncodedTensor {
  /**
   * The dimensions of the tensor.
   */
  readonly dims: readonly number[];
  /**
   * The data type of the tensor.
   */
  readonly type: string;
  /**
   * The tensor's buffer data, encoded as a Base64 string so it can cross the React Native
   * bridge (which does not support ArrayBuffer). String tensors are passed through as a plain
   * string array without Base64 encoding.
   */
  readonly data: string|string[];
}

/**
 * Binding exports a simple synchronized inference session object wrap.
 */
export declare namespace Binding {
  type ModelLoadInfoType = ModelLoadInfo;
  type EncodedTensorType = EncodedTensor;

  type SessionOptions = InferenceSession.SessionOptions;
  type RunOptions = InferenceSession.RunOptions;

  // Input feeds: map of input name to its Base64-encoded tensor.
  type FeedsType = {[name: string]: EncodedTensor};

  // SessionHandler FetchesType is different from native module's one.
  // It's because Java API doesn't support preallocated output values,
  // so only the requested output names are passed, not output tensors.
  type FetchesType = string[];

  // Run results: map of output name to its Base64-encoded tensor.
  type ReturnType = {[name: string]: EncodedTensor};

  /**
   * Contract implemented by the native Onnxruntime module.
   */
  interface InferenceSession {
    // Loads a model from a file path and resolves with the session key and input/output names.
    loadModel(modelPath: string, options: SessionOptions): Promise<ModelLoadInfoType>;
    // Optional on platforms that support loading a model from an in-memory Base64 buffer.
    loadModelFromBase64EncodedBuffer?(buffer: string, options: SessionOptions): Promise<ModelLoadInfoType>;
    // Runs inference on the session identified by `key`, returning the requested outputs.
    run(key: string, feeds: FeedsType, fetches: FetchesType, options: RunOptions): Promise<ReturnType>;
  }
}

// Export the native Onnxruntime module, typed as the inference session binding.
// NativeModules entries are untyped, so a cast is required here.
export const binding = NativeModules.Onnxruntime as Binding.InferenceSession;