# SAM ## 论文 Segment Anything - https://arxiv.org/abs/2304.02643 ## 模型结构 ![](./assets/model_diagram.png) 如图,该模型的网络结构主要分三个部分:Image encoder、Prompt encoder和Lightweight mask decoder。 ### Image encoder 使用ViT-H/16网络处理高分辨率输入,输出是输入图像的16倍缩小的嵌入(64×64)。通道维度降低至256,通过1×1和3×3卷积层。 ### Prompt encoder 映射到256维向量嵌入,包括位置编码和前景/背景信息。框由左上角和右下角嵌入对表示。文本编码器使用CLIP。 ### Lightweight mask decoder ![](./assets/mask_decoder.PNG) 图像嵌入通过两个转置卷积层放大4倍,经过MLP输出掩码。Transformer使用256嵌入维度,64×64图像嵌入的交叉注视层使用128通道维度。 ## 算法原理 ![](./assets/algorithm.png) SAM分为图像编码器和快速提示编码器/掩码解码器,可以重用相同的image embedding图像嵌入(并摊销其成本)与不同的提示。给定image embedding图像嵌入,提示编码器和掩码解码器可以在web浏览器中预测掩码。为了使SAM实现模糊感知,设计它来预测单个提示的多个掩码,从而使SAM能够自然地处理模糊性。 ## 环境配置 ### Docker(方法一) 从[光源](https://www.sourcefind.cn/#/service-list)拉取镜像 ``` docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 docker run -it --network=host --name=SAM_pytorch -v /opt/hyhal:/opt/hyhal:ro --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16G --group-add video --cap-add=SYS_PTRACE image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10 /bin/bash ``` 安装其他依赖: ``` pip install opencv-python pycocotools matplotlib onnxruntime onnx ``` ### Dockerfile(方法二) ``` cd /path/to/dockerfile docker build --no-cache -t sam_pytorch:latest . docker run -it --network=host --name=SAM_pytorch -v /opt/hyhal:/opt/hyhal:ro --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=16g --group-add video --cap-add=SYS_PTRACE -it SAM_pytorch:latest bash ``` ### Anaconda(方法三) 1、关于本项目DCU显卡所需的特殊深度学习库可从光合开发者社区下载安装: https://developer.hpccube.com/tool/ ``` DTK软件栈:dtk24.04.1 python:python3.10 torch:2.1.0 torchvision:0.16.0 ``` Tips:以上dtk软件栈、python、torch等DCU相关工具版本需要严格一一对应 2、安装其他依赖 直接使用pip install的方式安装 ``` pip install opencv-python pycocotools matplotlib onnxruntime onnx pip install git+https://github.com/facebookresearch/segment-anything.git ``` 或下载后本地安装 ``` pip install opencv-python pycocotools matplotlib onnxruntime onnx git clone git@github.com:facebookresearch/segment-anything.git cd segment-anything pip install -e . ``` ## 数据集 在本测试中训练部分数据集使用COCO2017数据集。 - 数据集快速下载中心: - [SCNet AIDatasets](http://113.200.138.88:18080/aidatasets) - 数据集快速通道下载地址: - [数据集快速下载地址](http://113.200.138.88:18080/aidatasets/coco2017) - 官方下载地址 - [训练数据](http://images.cocodataset.org/zips/train2017.zip) - [验证数据](http://images.cocodataset.org/zips/val2017.zip) - [测试数据](http://images.cocodataset.org/zips/test2017.zip) - [标签数据](https://github.com/ultralytics/yolov5/releases/download/v1.0/coco2017labels.zip) 数据集的目录结构如下: ``` ├── images │ ├── train2017 │ ├── val2017 │ ├── test2017 ├── labels │ ├── train2017 │ ├── val2017 ├── annotations │ ├── instances_val2017.json ├── LICENSE ├── README.txt ├── test-dev2017.txt ├── train2017.txt ├── val2017.txt ``` 推理数据集名称:SA-1B Dataset 完整数据集可在[这里](https://ai.facebook.com/datasets/segment-anything-downloads/)进行下载 项目中用于试验训练的迷你数据集结构如下 ``` ── notebooks │   ├── images │ │   ├── dog.jpg │ │   ├── groceries.jpg │ │ └── trunk.jpg ``` ## 训练 ### 微调 官网提供了生成掩码的预训练权重和生成掩码的脚本,没有提供训练脚本,但可使用第三方提供的示例脚本微调 如果您有兴趣,参考[这里](https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam/train.py). ### 单机多卡 预训练模型在[这里](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)下载 ``` git clone https://github.com/luca-medeiros/lightning-sam.git cd lightning-sam 修改pyproject.toml文件中的第六行为documentation = "https://this/needs/to/be/something/otherwise/poetry/complains" pip install . pip install tensorboardX==2.6.2.2 cd lightning_sam 根据实际情况在config.py中修改相关参数:卡数、数据集路径、checkpoint模型路径 python train.py ``` pip install . 过程中会顶掉DCU版本的pytorch、torchvision、triton,需要到[开发者社区](https://cancon.hpccube.com:65024/4/main/pytorch)下载DCU版本对应包 ## 推理 ``` python scripts/amg.py --checkpoint --model-type --input --output ``` 注:checkpoint预训练模型在[这里](https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth)下载 指令中: 代表选择权重的路径 代表不同的模型,可选择'vit_h'、'vit_l'、 'vit_b' 代表输入图片或者文件夹的路径 代表分割结果保存路径 ## result ![](./ouputs/000000524456/0.png) 掩码生成的部分结果在同级目录outputs中可以查看,结果示例如上图,官方提供demo可在[这里](https://segment-anything.com/demo)试用 ## 精度 无 ## 应用场景 ### 算法类别 图像分割 ### 热点应用行业 能源,医疗,网安 ## 源码仓库及问题反馈 https://developer.hpccube.com/codes/modelzoo/sam_pytorch ## 参考资料 https://github.com/facebookresearch/segment-anything https://github.com/luca-medeiros/lightning-sam/blob/main/lightning_sam(第三方SAM微调)