## Segment Anything Simple Web demo

This **front-end only**, React-based web demo shows how to load a fixed image and the corresponding `.npy` file containing its SAM image embedding, then run the SAM ONNX model in the browser using WebAssembly, with multithreading enabled by `SharedArrayBuffer`, Web Workers, and SIMD128.

<img src="https://github.com/facebookresearch/segment-anything/raw/main/assets/minidemo.gif" width="500"/>

## Run the app

Install Yarn:

```
npm install -g yarn
```

Build and run:

```
yarn && yarn start
```

Navigate to [`http://localhost:8081/`](http://localhost:8081/)

Move your cursor around to see the mask prediction update in real time.

## Export the image embedding

In the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb), upload the image of your choice, then generate and save the corresponding embedding.

Initialize the predictor:

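```python
from segment_anything import SamPredictor, sam_model_registry

# Path to the downloaded SAM checkpoint and its matching model type.
checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=checkpoint)
sam.to(device='cuda')  # requires a CUDA-capable GPU
predictor = SamPredictor(sam)
```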

Set the new image and export the embedding:

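```python
import cv2
import numpy as np

image = cv2.imread('src/assets/dogs.jpg')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; set_image expects RGB by default
predictor.set_image(image)
image_embedding = predictor.get_image_embedding().cpu().numpy()
np.save("dogs_embedding.npy", image_embedding)
```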

Save the new image and embedding in `src/assets/data`.
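
Before wiring a new embedding into the app, it can help to confirm that the saved file has the shape you expect. The sketch below is a minimal `.npy` reader written for illustration; `parseNpy` and `NpyArray` are hypothetical names, not part of the demo, and the code assumes a C-ordered, little-endian `float32` array such as the one written by the `np.save` call above.

```ts
// Minimal .npy reader (illustrative sketch; `parseNpy` is a hypothetical helper).
// Assumes C-ordered, little-endian float32 data and a little-endian host.
interface NpyArray {
  shape: number[];
  data: Float32Array;
}

function parseNpy(buffer: ArrayBuffer): NpyArray {
  const bytes = new Uint8Array(buffer);
  // Every .npy file starts with the byte 0x93 followed by "NUMPY".
  const magic = [0x93, 0x4e, 0x55, 0x4d, 0x50, 0x59];
  for (let i = 0; i < magic.length; i++) {
    if (bytes[i] !== magic[i]) throw new Error("Not a .npy file");
  }
  const view = new DataView(buffer);
  // Format 1.x stores the header length as uint16; later versions use uint32.
  const major = bytes[6];
  const headerStart = major === 1 ? 10 : 12;
  const headerLen = major === 1 ? view.getUint16(8, true) : view.getUint32(8, true);

  // The header is an ASCII Python dict literal, e.g.
  // {'descr': '<f4', 'fortran_order': False, 'shape': (1, 256, 64, 64), }
  let header = "";
  for (let i = 0; i < headerLen; i++) {
    header += String.fromCharCode(bytes[headerStart + i]);
  }
  const descr = /'descr':\s*'([^']+)'/.exec(header);
  const order = /'fortran_order':\s*(True|False)/.exec(header);
  const shapeText = /'shape':\s*\(([^)]*)\)/.exec(header);
  if (!descr || !order || !shapeText) throw new Error("Malformed .npy header");
  if (descr[1] !== "<f4" || order[1] !== "False") {
    throw new Error("Expected C-ordered little-endian float32 data");
  }
  const shape = shapeText[1]
    .split(",")
    .map((s) => s.trim())
    .filter((s) => s.length > 0)
    .map(Number);
  const count = shape.reduce((a, b) => a * b, 1);
  const dataStart = headerStart + headerLen;
  // slice() copies the bytes, so the Float32Array view is always 4-byte aligned.
  const data = new Float32Array(buffer.slice(dataStart, dataStart + count * 4));
  if (data.length !== count) throw new Error("File is shorter than its header claims");
  return { shape, data };
}
```

For a SAM embedding, `parseNpy(buffer).shape` should normally be `[1, 256, 64, 64]`; a different shape suggests the wrong array was saved.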

## Export the ONNX model

You also need to export the quantized ONNX model from the [ONNX Model Example notebook](https://github.com/facebookresearch/segment-anything/blob/main/notebooks/onnx_model_example.ipynb).

Run the notebook cell that saves the `sam_onnx_quantized_example.onnx` file, then download the file and copy it to `/model/sam_onnx_quantized_example.onnx`.

Here is a snippet of the export/quantization code:

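```python
from onnxruntime.quantization import QuantType, quantize_dynamic

onnx_model_path = "sam_onnx_example.onnx"
onnx_model_quantized_path = "sam_onnx_quantized_example.onnx"
quantize_dynamic(
    model_input=onnx_model_path,
    model_output=onnx_model_quantized_path,
    # Some newer onnxruntime releases may reject optimize_model;
    # drop the argument if it raises a TypeError.
    optimize_model=True,
    per_channel=False,
    reduce_range=False,
    weight_type=QuantType.QUInt8,
)
```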

**NOTE: if you change the ONNX model by using a new checkpoint you need to also re-export the embedding.**

## Update the image, embedding, and model in the app

Update the following file paths at the top of `App.tsx`:

```ts
const IMAGE_PATH = "/assets/data/dogs.jpg";
const IMAGE_EMBEDDING = "/assets/data/dogs_embedding.npy";
const MODEL_DIR = "/model/sam_onnx_quantized_example.onnx";
```

## ONNX multithreading with SharedArrayBuffer

To use multithreading, the appropriate headers must be set to put the page in a cross-origin isolated state, which enables the use of `SharedArrayBuffer` (see this [blog post](https://cloudblogs.microsoft.com/opensource/2021/09/02/onnx-runtime-web-running-your-machine-learning-model-in-browser/) for more details).

The headers below are set in `configs/webpack/dev.js`:

```js
headers: {
    "Cross-Origin-Opener-Policy": "same-origin",
    "Cross-Origin-Embedder-Policy": "credentialless",
}
```
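
Because `SharedArrayBuffer` is only available when the page is cross-origin isolated, it can be useful to fall back to a single thread when the headers are missing. The following is a minimal sketch under that assumption; `pickWasmThreads`, `RuntimeInfo`, and the default cap of 4 threads are hypothetical choices for illustration, not values taken from the demo.

```ts
// Values the browser exposes as `self.crossOriginIsolated` and
// `navigator.hardwareConcurrency`, passed in explicitly so the logic is testable.
interface RuntimeInfo {
  crossOriginIsolated?: boolean;
  hardwareConcurrency?: number;
}

// Hypothetical helper: choose how many WebAssembly threads to request.
function pickWasmThreads(info: RuntimeInfo, maxThreads: number = 4): number {
  // Without cross-origin isolation SharedArrayBuffer is unavailable,
  // so multithreaded WebAssembly cannot be used.
  if (info.crossOriginIsolated !== true) return 1;
  const cores = info.hardwareConcurrency;
  if (cores === undefined || !isFinite(cores) || cores < 1) return 1;
  // Never request more threads than the cap, and never fewer than one.
  return Math.max(1, Math.min(Math.floor(cores), maxThreads));
}
```

In the browser this could be called as `pickWasmThreads({ crossOriginIsolated: self.crossOriginIsolated, hardwareConcurrency: navigator.hardwareConcurrency })`. With onnxruntime-web, the result would typically be assigned to `ort.env.wasm.numThreads` before the inference session is created.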

## Structure of the app

**`App.tsx`**

- Initializes ONNX model
- Loads image embedding and image
- Runs the ONNX model based on input prompts

**`Stage.tsx`**

- Handles mouse move interaction to update the ONNX model prompt

**`Tool.tsx`**

- Renders the image and the mask prediction

**`helpers/maskUtils.tsx`**

- Conversion of ONNX model output from array to an HTMLImageElement

**`helpers/onnxModelAPI.tsx`**

- Formats the inputs for the ONNX model

**`helpers/scaleHelper.tsx`**

- Handles image scaling logic for SAM (longest side 1024)

**`hooks/`**

- Handle shared state for the app
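
The scaling rule noted for `helpers/scaleHelper.tsx` and the input formatting noted for `helpers/onnxModelAPI.tsx` can be sketched as follows. This is an illustrative approximation, not the demo's source: the names `getModelScale`, `buildPointInputs`, `ModelScale`, and `Click` are assumptions, and the extra padding point with label `-1` follows the convention used in the ONNX Model Example notebook when no box prompt is given.

```ts
// Longest image side that SAM works with (1024, as stated above).
const LONG_SIDE_LENGTH = 1024;

interface ModelScale {
  width: number;    // original image width in pixels
  height: number;   // original image height in pixels
  samScale: number; // factor mapping original pixels to SAM's input space
}

// Scale so that the longer of the two sides becomes LONG_SIDE_LENGTH.
function getModelScale(width: number, height: number): ModelScale {
  const samScale = LONG_SIDE_LENGTH / Math.max(width, height);
  return { width, height, samScale };
}

interface Click {
  x: number;         // click position in original image pixels
  y: number;
  clickType: number; // 1 = foreground point, 0 = background point
}

// Flatten clicks into the coordinate and label arrays a point prompt needs.
function buildPointInputs(clicks: Click[], scale: ModelScale) {
  const n = clicks.length;
  // Reserve one extra slot for the padding point (0, 0) with label -1.
  const coords = new Float32Array(2 * (n + 1));
  const labels = new Float32Array(n + 1);
  clicks.forEach((click, i) => {
    coords[2 * i] = click.x * scale.samScale;
    coords[2 * i + 1] = click.y * scale.samScale;
    labels[i] = click.clickType;
  });
  coords[2 * n] = 0;
  coords[2 * n + 1] = 0;
  labels[n] = -1;
  // Tensor dimensions: a batch of 1, n + 1 points, 2 coordinates each.
  const coordDims: number[] = [1, n + 1, 2];
  const labelDims: number[] = [1, n + 1];
  return { coords, labels, coordDims, labelDims };
}
```

For a 2048 x 1024 image, `getModelScale(2048, 1024).samScale` is `0.5`, so a click at `(100, 200)` is sent to the model as `(50, 100)`.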