App.tsx 3.37 KB
Newer Older
gaoqiong's avatar
gaoqiong committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.

import * as React from 'react';
import{Image, Text, TextInput, View} from 'react-native';
// onnxruntime-react-native package is installed when bootstrapping
// eslint-disable-next-line import/no-extraneous-dependencies
import{InferenceSession, Tensor} from 'onnxruntime-react-native';
import MNIST, {MNISTInput, MNISTOutput, MNISTResult, } from './mnist-data-handler';
import{Buffer} from 'buffer';

// Component state: all fields start as null and are populated during
// componentDidMount once the model and sample image have been loaded.
interface State {
  /** ONNX Runtime inference session created from the local model file. */
  session: InferenceSession | null;
  /** Result string returned by MNIST.postprocess after the last inference. */
  output: string | null;
  /** Local path of the sample MNIST image to display and classify. */
  imagePath: string | null;
}

// eslint-disable-next-line @typescript-eslint/ban-types
export default class App extends React.PureComponent<{}, State> {
  // eslint-disable-next-line @typescript-eslint/ban-types
  constructor(props : {} | Readonly<{}>) {
    super(props);

    this.state = {
      session : null,
      output : null,
      imagePath : null,
    };
  }

  // Load the sample image path and the MNIST model when the app mounts,
  // then run a first inference.
  async componentDidMount() : Promise<void> {
    if (!this.state.session) {
      try {
        const imagePath = await MNIST.getImagePath();
        this.setState({imagePath});

        const modelPath = await MNIST.getLocalModelPath();
        const session : InferenceSession = await InferenceSession.create(modelPath);

        // setState is asynchronous (and batched in React 18+), so kick off the
        // first inference from the completion callback rather than right after
        // the call — otherwise infer() could still observe a null session.
        this.setState({session}, () => {
          void this.infer();
        });
      } catch (err) {
        // 'err' is unknown in a catch clause; narrow before reading '.message'.
        console.log(err instanceof Error ? err.message : String(err));
      }
    }
  }

  // Run the loaded model against the sample image and store the result string.
  infer = async() : Promise<void> => {
    try {
      const {session, imagePath} = this.state;
      if (!session || !imagePath) {
        // Nothing to do until componentDidMount has finished loading both.
        return;
      }

      const options : InferenceSession.RunOptions = {};

      // The native side returns base64-encoded float32 buffers; decode each
      // one into a Tensor keyed by the model's input name.
      const mnistInput : MNISTInput = await MNIST.preprocess(imagePath);
      const input : {[name:string] : Tensor} = {};
      for (const key of Object.keys(mnistInput)) {
        const buffer = Buffer.from(mnistInput[key].data, 'base64');
        const tensorData =
            new Float32Array(buffer.buffer, buffer.byteOffset, buffer.length / Float32Array.BYTES_PER_ELEMENT);
        input[key] = new Tensor(mnistInput[key].type as keyof Tensor.DataTypeMap, tensorData, mnistInput[key].dims);
      }

      const output : InferenceSession.ReturnType =
          await session.run(input, session.outputNames, options);

      // Re-encode each output tensor as base64 for the native postprocess step.
      const mnistOutput : MNISTOutput = {};
      for (const key of Object.keys(output)) {
        const buffer = (output[key].data as Float32Array).buffer;
        mnistOutput[key] = {
          data : Buffer.from(buffer, 0, buffer.byteLength).toString('base64'),
        };
      }
      const result : MNISTResult = await MNIST.postprocess(mnistOutput);

      this.setState({
        output : result.result
      });
    } catch (err) {
      // 'err' is unknown in a catch clause; narrow before reading '.message'.
      console.log(err instanceof Error ? err.message : String(err));
    }
  };

  // Render the sample image (once its path is known) and, below it, the
  // inference result (once available).
  render() : JSX.Element {
    const {output, imagePath} = this.state;

    return (
      <View>
        <Text>{'\n'}</Text>
        {imagePath && (
          <Image
            source={{
              uri: imagePath,
            }}
            style={{
              height: 200,
              width: 200,
              resizeMode: 'stretch',
            }}
          />
        )}
        {output && (
          <TextInput accessibilityLabel='output'>
            Result: {output}
          </TextInput>
        )}
      </View>
    );
  }
}