Commit 50dd7d3e authored by dengjb's avatar dengjb
Browse files

update

parents
Pipeline #3040 canceled with stages
MIT License
Copyright (c) 2020 Gongfan Fang
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# MiniMax-M2
## 论文
暂无
## 模型简介
MiniMax-M2 重新定义了代理的效率。它是一个紧凑、快速且成本效益高的 MoE 模型(总参数量为 2300 DeeplabV3plus 是一种先进的用于语义分割任务的深度学习模型。DeepLabV3plus模型采用了编码器-解码器(Encoder-Decoder)结构,通过编码器提取图像特征,再通过解码器将这些特征映射回原始图像尺寸,实现像素级的分类。具体来说,模型的主干网络(论文中对ResNet101或Xception做了实验)负责特征提取,特征提取分为高层语义提取和底层的语义提取两个部分。然后,模型会利用空洞卷积(Dilated Convolution)技术,构建了ASPP(Atrous Spatial Pyramid Pooling)模块,提高模型在不同尺度特征提取上的能力。最后,通过解码器恢复图像的细节信息,得到最终的分割结果。总体流程如下:
![alt text](image.png)
## 环境依赖
| 软件 | 版本 |
| :------: | :------: |
| DTK | 25.04.2 |
| python | 3.10.12 |
| transformers | 4.57.1 |
| vllm | 0.11.0+das.opt1.alpha.8e22ded.dtk25042 |
| torch | 2.5.1+das.opt1.dtk25042 |
| triton | 3.1+das.opt1.3c5d12d.dtk25041 |
| flash_attn | 2.6.1+das.opt1.dtk2504 |
| flash_mla | 1.0.0+das.opt1.dtk25042 |
当前仅支持镜像:
- 挂载地址`-v`根据实际模型情况修改
```bash
docker run -it --shm-size 60g --network=host --name minimax_m2 --privileged --device=/dev/kfd --device=/dev/dri --device=/dev/mkfd --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root -v /opt/hyhal/:/opt/hyhal/:ro -v /path/your_code_path/:/path/your_code_path/ image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.5.1-ubuntu22.04-dtk25.04.2-py3.10 bash
```
更多镜像可前往[光源](https://sourcefind.cn/#/service-list)下载使用。
关于本项目DCU显卡所需的特殊深度学习库可从[光合](https://developer.sourcefind.cn/tool/)开发者社区下载安装。
## 数据集
### Pascal VOC数据集
你可以通过运行train.py 使用 "--download" 选项下载并解压 pascal voc dataset. 默认地址为: './datasets/data':
```
/datasets
/data
/VOCdevkit
/VOC2012
/SegmentationClass
/JPEGImages
...
...
/VOCtrainval_11-May-2012.tar
...
```
### 下载cityscapes并解压到 'datasets/data/cityscapes'
```
/datasets
/data
/cityscapes
/gtFine
/leftImg8bit
```
## 训练
### 1. 可用模型架构
| DeepLabV3 | DeepLabV3+ |
| :---: | :---: |
|deeplabv3_resnet50|deeplabv3plus_resnet50|
|deeplabv3_resnet101|deeplabv3plus_resnet101|
|deeplabv3_mobilenet|deeplabv3plus_mobilenet ||
|deeplabv3_hrnetv2_48 | deeplabv3plus_hrnetv2_48 |
|deeplabv3_hrnetv2_32 | deeplabv3plus_hrnetv2_32 |
|deeplabv3_xception | deeplabv3plus_xception |
请参考 [network/modeling.py](https://github.com/VainF/DeepLabV3Plus-Pytorch/blob/master/network/modeling.py) 适用于所有模型结构。
模型下载: [Dropbox](https://www.dropbox.com/sh/w3z9z8lqpi8b2w7/AAB0vkl4F5vy6HdIhmRCTKHSa?dl=0), [腾讯微盘](https://share.weiyun.com/qqx78Pv5)
### 2。基于数据集训练
#### 1. Pascal VOC进行训练
```bash
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16
```
#### 2. Cityscapes进行训练
```bash
python main.py --model deeplabv3plus_mobilenet --dataset cityscapes --enable_vis --vis_port 28333 --gpu_id 0 --lr 0.1 --crop_size 768 --batch_size 16 --output_stride 16 --data_root ./datasets/data/cityscapes
```
## 推理
推理测试
```bash
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16 --ckpt checkpoints/best_deeplabv3plus_mobilenet_voc_os16.pth --test_only --save_val_results
```
### 精度
DCU与GPU精度一致,推理框架:pytorch。
## 预训练权重
模型名称 | Batch Size | 权重大小 | train/val OS|DCU型号 | mIoU | 权重下载-Dropbox | 权重下载-微云 |
| :-------- | :-------------: | :----: | :----: | :-----------: | :--------: | :--------: | :----: |
| DeepLabV3-MobileNet | 16 | 6.0G | 16/16 | K100AI | 0.701 | [Download](https://www.dropbox.com/s/uhksxwfcim3nkpo/best_deeplabv3_mobilenet_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/A4ubD1DD) |
| DeepLabV3-ResNet50 | 16 | 51.4G | 16/16 | K100AI | 0.769 | [Download](https://www.dropbox.com/s/3eag5ojccwiexkq/best_deeplabv3_resnet50_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/33eLjnVL) |
| DeepLabV3-ResNet101 | 16 | 72.1G | 16/16 | K100AI | 0.773 | [Download](https://www.dropbox.com/s/vtenndnsrnh4068/best_deeplabv3_resnet101_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/iCkzATAw) |
| DeepLabV3Plus-MobileNet | 16 | 17.0G | 16/16 | K100AI | 0.711 | [Download](https://www.dropbox.com/s/0idrhwz6opaj7q4/best_deeplabv3plus_mobilenet_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/djX6MDwM) |
| DeepLabV3Plus-ResNet50 | 16 | 62.7G | 16/16 | K100AI | 0.772 | [Download](https://www.dropbox.com/s/dgxyd3jkyz24voa/best_deeplabv3plus_resnet50_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/uTM4i2jG) |
| DeepLabV3Plus-ResNet101 | 16 | 83.4G | 16/16 | K100AI | 0.783 | [Download](https://www.dropbox.com/s/bm3hxe7wmakaqc5/best_deeplabv3plus_resnet101_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/UNPZr3dk) |
## 源码仓库及问题反馈
- https://developer.sourcefind.cn/codes/modelzoo/deeplabv3-plus_pytorch
## 参考资料
- https://developer.sourcefind.cn/codes/modelzoo/deeplabv3-plus_pytorch
# DeepLabv3Plus-Pytorch
Pretrained DeepLabv3, DeepLabv3+ for Pascal VOC & Cityscapes.
## Quick Start
### 1. Available Architectures
| DeepLabV3 | DeepLabV3+ |
| :---: | :---: |
|deeplabv3_resnet50|deeplabv3plus_resnet50|
|deeplabv3_resnet101|deeplabv3plus_resnet101|
|deeplabv3_mobilenet|deeplabv3plus_mobilenet ||
|deeplabv3_hrnetv2_48 | deeplabv3plus_hrnetv2_48 |
|deeplabv3_hrnetv2_32 | deeplabv3plus_hrnetv2_32 |
|deeplabv3_xception | deeplabv3plus_xception |
please refer to [network/modeling.py](https://github.com/VainF/DeepLabV3Plus-Pytorch/blob/master/network/modeling.py) for all model entries.
Download pretrained models: [Dropbox](https://www.dropbox.com/sh/w3z9z8lqpi8b2w7/AAB0vkl4F5vy6HdIhmRCTKHSa?dl=0), [Tencent Weiyun](https://share.weiyun.com/qqx78Pv5)
Note: The HRNet backbone was contributed by @timothylimyl. A pre-trained backbone is available at [google drive](https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing).
### 2. Load the pretrained model:
```python
model = network.modeling.__dict__[MODEL_NAME](num_classes=NUM_CLASSES, output_stride=OUTPUT_SRTIDE)
model.load_state_dict( torch.load( PATH_TO_PTH )['model_state'] )
```
### 3. Visualize segmentation outputs:
```python
outputs = model(images)
preds = outputs.max(1)[1].detach().cpu().numpy()
colorized_preds = val_dst.decode_target(preds).astype('uint8') # To RGB images, (N, H, W, 3), ranged 0~255, numpy array
# Do whatever you like here with the colorized segmentation maps
colorized_preds = Image.fromarray(colorized_preds[0]) # to PIL Image
```
### 4. Atrous Separable Convolution
**Note**: All pre-trained models in this repo were trained without atrous separable convolution.
Atrous Separable Convolution is supported in this repo. We provide a simple tool ``network.convert_to_separable_conv`` to convert ``nn.Conv2d`` to ``AtrousSeparableConvolution``. **Please run main.py with '--separable_conv' if it is required**. See 'main.py' and 'network/_deeplab.py' for more details.
### 5. Prediction
Single image:
```bash
python predict.py --input datasets/data/cityscapes/leftImg8bit/train/bremen/bremen_000000_000019_leftImg8bit.png --dataset cityscapes --model deeplabv3plus_mobilenet --ckpt checkpoints/best_deeplabv3plus_mobilenet_cityscapes_os16.pth --save_val_results_to test_results
```
Image folder:
```bash
python predict.py --input datasets/data/cityscapes/leftImg8bit/train/bremen --dataset cityscapes --model deeplabv3plus_mobilenet --ckpt checkpoints/best_deeplabv3plus_mobilenet_cityscapes_os16.pth --save_val_results_to test_results
```
### 6. New backbones
Please refer to [this commit (Xception)](https://github.com/VainF/DeepLabV3Plus-Pytorch/commit/c4b51e435e32b0deba5fc7c8ff106293df90590d) for more details about how to add new backbones.
### 7. New datasets
You can train deeplab models on your own datasets. Your ``torch.utils.data.Dataset`` should provide a decoding method that transforms your predictions to colorized images, just like the [VOC Dataset](https://github.com/VainF/DeepLabV3Plus-Pytorch/blob/bfe01d5fca5b6bb648e162d522eed1a9a8b324cb/datasets/voc.py#L156):
```python
class MyDataset(data.Dataset):
...
@classmethod
def decode_target(cls, mask):
"""decode semantic mask to RGB image"""
return cls.cmap[mask]
```
## Results
### 1. Performance on Pascal VOC2012 Aug (21 classes, 513 x 513)
Training: 513x513 random crop
validation: 513x513 center crop
| Model | Batch Size | FLOPs | train/val OS | mIoU | Dropbox | Tencent Weiyun |
| :-------- | :-------------: | :----: | :-----------: | :--------: | :--------: | :----: |
| DeepLabV3-MobileNet | 16 | 6.0G | 16/16 | 0.701 | [Download](https://www.dropbox.com/s/uhksxwfcim3nkpo/best_deeplabv3_mobilenet_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/A4ubD1DD) |
| DeepLabV3-ResNet50 | 16 | 51.4G | 16/16 | 0.769 | [Download](https://www.dropbox.com/s/3eag5ojccwiexkq/best_deeplabv3_resnet50_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/33eLjnVL) |
| DeepLabV3-ResNet101 | 16 | 72.1G | 16/16 | 0.773 | [Download](https://www.dropbox.com/s/vtenndnsrnh4068/best_deeplabv3_resnet101_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/iCkzATAw) |
| DeepLabV3Plus-MobileNet | 16 | 17.0G | 16/16 | 0.711 | [Download](https://www.dropbox.com/s/0idrhwz6opaj7q4/best_deeplabv3plus_mobilenet_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/djX6MDwM) |
| DeepLabV3Plus-ResNet50 | 16 | 62.7G | 16/16 | 0.772 | [Download](https://www.dropbox.com/s/dgxyd3jkyz24voa/best_deeplabv3plus_resnet50_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/uTM4i2jG) |
| DeepLabV3Plus-ResNet101 | 16 | 83.4G | 16/16 | 0.783 | [Download](https://www.dropbox.com/s/bm3hxe7wmakaqc5/best_deeplabv3plus_resnet101_voc_os16.pth?dl=0) | [Download](https://share.weiyun.com/UNPZr3dk) |
### 2. Performance on Cityscapes (19 classes, 1024 x 2048)
Training: 768x768 random crop
validation: 1024x2048
| Model | Batch Size | FLOPs | train/val OS | mIoU | Dropbox | Tencent Weiyun |
| :-------- | :-------------: | :----: | :-----------: | :--------: | :--------: | :----: |
| DeepLabV3Plus-MobileNet | 16 | 135G | 16/16 | 0.721 | [Download](https://www.dropbox.com/s/753ojyvsh3vdjol/best_deeplabv3plus_mobilenet_cityscapes_os16.pth?dl=0) | [Download](https://share.weiyun.com/aSKjdpbL)
| DeepLabV3Plus-ResNet101 | 16 | N/A | 16/16 | 0.762 | [Download](https://drive.google.com/file/d/1t7TC8mxQaFECt4jutdq_NMnWxdm6B-Nb/view?usp=sharing) | N/A |
#### Segmentation Results on Pascal VOC2012 (DeepLabv3Plus-MobileNet)
<div>
<img src="samples/1_image.png" width="20%">
<img src="samples/1_target.png" width="20%">
<img src="samples/1_pred.png" width="20%">
<img src="samples/1_overlay.png" width="20%">
</div>
<div>
<img src="samples/23_image.png" width="20%">
<img src="samples/23_target.png" width="20%">
<img src="samples/23_pred.png" width="20%">
<img src="samples/23_overlay.png" width="20%">
</div>
<div>
<img src="samples/114_image.png" width="20%">
<img src="samples/114_target.png" width="20%">
<img src="samples/114_pred.png" width="20%">
<img src="samples/114_overlay.png" width="20%">
</div>
#### Segmentation Results on Cityscapes (DeepLabv3Plus-MobileNet)
<div>
<img src="samples/city_1_target.png" width="45%">
<img src="samples/city_1_overlay.png" width="45%">
</div>
<div>
<img src="samples/city_6_target.png" width="45%">
<img src="samples/city_6_overlay.png" width="45%">
</div>
#### Visualization of training
![trainvis](samples/visdom-screenshoot.png)
## Pascal VOC
### 1. Requirements
```bash
pip install -r requirements.txt
```
### 2. Prepare Datasets
#### 2.1 Standard Pascal VOC
You can run train.py with "--download" option to download and extract pascal voc dataset. The defaut path is './datasets/data':
```
/datasets
/data
/VOCdevkit
/VOC2012
/SegmentationClass
/JPEGImages
...
...
/VOCtrainval_11-May-2012.tar
...
```
#### 2.2 Pascal VOC trainaug (Recommended!!)
See chapter 4 of [2]
The original dataset contains 1464 (train), 1449 (val), and 1456 (test) pixel-level annotated images. We augment the dataset by the extra annotations provided by [76], resulting in 10582 (trainaug) training images. The performance is measured in terms of pixel intersection-over-union averaged across the 21 classes (mIOU).
*./datasets/data/train_aug.txt* includes the file names of 10582 trainaug images (val images are excluded). Please to download their labels from [Dropbox](https://www.dropbox.com/s/oeu149j8qtbs1x0/SegmentationClassAug.zip?dl=0) or [Tencent Weiyun](https://share.weiyun.com/5NmJ6Rk). Those labels come from [DrSleep's repo](https://github.com/DrSleep/tensorflow-deeplab-resnet).
Extract trainaug labels (SegmentationClassAug) to the VOC2012 directory.
```
/datasets
/data
/VOCdevkit
/VOC2012
/SegmentationClass
/SegmentationClassAug # <= the trainaug labels
/JPEGImages
...
...
/VOCtrainval_11-May-2012.tar
...
```
### 3. Training on Pascal VOC2012 Aug
#### 3.1 Visualize training (Optional)
Start visdom sever for visualization. Please remove '--enable_vis' if visualization is not needed.
```bash
# Run visdom server on port 28333
visdom -port 28333
```
#### 3.2 Training with OS=16
Run main.py with *"--year 2012_aug"* to train your model on Pascal VOC2012 Aug. You can also parallel your training on 4 GPUs with '--gpu_id 0,1,2,3'
**Note: There is no SyncBN in this repo, so training with *multple GPUs and small batch size* may degrades the performance. See [PyTorch-Encoding](https://hangzhang.org/PyTorch-Encoding/tutorials/syncbn.html) for more details about SyncBN**
```bash
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16
```
#### 3.3 Continue training
Run main.py with '--continue_training' to restore the state_dict of optimizer and scheduler from YOUR_CKPT.
```bash
python main.py ... --ckpt YOUR_CKPT --continue_training
```
#### 3.4. Testing
Results will be saved at ./results.
```bash
python main.py --model deeplabv3plus_mobilenet --enable_vis --vis_port 28333 --gpu_id 0 --year 2012_aug --crop_val --lr 0.01 --crop_size 513 --batch_size 16 --output_stride 16 --ckpt checkpoints/best_deeplabv3plus_mobilenet_voc_os16.pth --test_only --save_val_results
```
## Cityscapes
### 1. Download cityscapes and extract it to 'datasets/data/cityscapes'
```
/datasets
/data
/cityscapes
/gtFine
/leftImg8bit
```
### 2. Train your model on Cityscapes
```bash
python main.py --model deeplabv3plus_mobilenet --dataset cityscapes --enable_vis --vis_port 28333 --gpu_id 0 --lr 0.1 --crop_size 768 --batch_size 16 --output_stride 16 --data_root ./datasets/data/cityscapes
```
## Reference
[1] [Rethinking Atrous Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1706.05587)
[2] [Encoder-Decoder with Atrous Separable Convolution for Semantic Image Segmentation](https://arxiv.org/abs/1802.02611)
from .voc import VOCSegmentation
from .cityscapes import Cityscapes
\ No newline at end of file
import json
import os
from collections import namedtuple
import torch
import torch.utils.data as data
from PIL import Image
import numpy as np
class Cityscapes(data.Dataset):
"""Cityscapes <http://www.cityscapes-dataset.com/> Dataset.
**Parameters:**
- **root** (string): Root directory of dataset where directory 'leftImg8bit' and 'gtFine' or 'gtCoarse' are located.
- **split** (string, optional): The image split to use, 'train', 'test' or 'val' if mode="gtFine" otherwise 'train', 'train_extra' or 'val'
- **mode** (string, optional): The quality mode to use, 'gtFine' or 'gtCoarse' or 'color'. Can also be a list to output a tuple with all specified target types.
- **transform** (callable, optional): A function/transform that takes in a PIL image and returns a transformed version. E.g, ``transforms.RandomCrop``
- **target_transform** (callable, optional): A function/transform that takes in the target and transforms it.
"""
# Based on https://github.com/mcordts/cityscapesScripts
CityscapesClass = namedtuple('CityscapesClass', ['name', 'id', 'train_id', 'category', 'category_id',
'has_instances', 'ignore_in_eval', 'color'])
classes = [
CityscapesClass('unlabeled', 0, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('ego vehicle', 1, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('rectification border', 2, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('out of roi', 3, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('static', 4, 255, 'void', 0, False, True, (0, 0, 0)),
CityscapesClass('dynamic', 5, 255, 'void', 0, False, True, (111, 74, 0)),
CityscapesClass('ground', 6, 255, 'void', 0, False, True, (81, 0, 81)),
CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
CityscapesClass('parking', 9, 255, 'flat', 1, False, True, (250, 170, 160)),
CityscapesClass('rail track', 10, 255, 'flat', 1, False, True, (230, 150, 140)),
CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
CityscapesClass('guard rail', 14, 255, 'construction', 2, False, True, (180, 165, 180)),
CityscapesClass('bridge', 15, 255, 'construction', 2, False, True, (150, 100, 100)),
CityscapesClass('tunnel', 16, 255, 'construction', 2, False, True, (150, 120, 90)),
CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
CityscapesClass('polegroup', 18, 255, 'object', 3, False, True, (153, 153, 153)),
CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
CityscapesClass('caravan', 29, 255, 'vehicle', 7, True, True, (0, 0, 90)),
CityscapesClass('trailer', 30, 255, 'vehicle', 7, True, True, (0, 0, 110)),
CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
CityscapesClass('license plate', -1, 255, 'vehicle', 7, False, True, (0, 0, 142)),
]
train_id_to_color = [c.color for c in classes if (c.train_id != -1 and c.train_id != 255)]
train_id_to_color.append([0, 0, 0])
train_id_to_color = np.array(train_id_to_color)
id_to_train_id = np.array([c.train_id for c in classes])
#train_id_to_color = [(0, 0, 0), (128, 64, 128), (70, 70, 70), (153, 153, 153), (107, 142, 35),
# (70, 130, 180), (220, 20, 60), (0, 0, 142)]
#train_id_to_color = np.array(train_id_to_color)
#id_to_train_id = np.array([c.category_id for c in classes], dtype='uint8') - 1
def __init__(self, root, split='train', mode='fine', target_type='semantic', transform=None):
self.root = os.path.expanduser(root)
self.mode = 'gtFine'
self.target_type = target_type
self.images_dir = os.path.join(self.root, 'leftImg8bit', split)
self.targets_dir = os.path.join(self.root, self.mode, split)
self.transform = transform
self.split = split
self.images = []
self.targets = []
if split not in ['train', 'test', 'val']:
raise ValueError('Invalid split for mode! Please use split="train", split="test"'
' or split="val"')
if not os.path.isdir(self.images_dir) or not os.path.isdir(self.targets_dir):
raise RuntimeError('Dataset not found or incomplete. Please make sure all required folders for the'
' specified "split" and "mode" are inside the "root" directory')
for city in os.listdir(self.images_dir):
img_dir = os.path.join(self.images_dir, city)
target_dir = os.path.join(self.targets_dir, city)
for file_name in os.listdir(img_dir):
self.images.append(os.path.join(img_dir, file_name))
target_name = '{}_{}'.format(file_name.split('_leftImg8bit')[0],
self._get_target_suffix(self.mode, self.target_type))
self.targets.append(os.path.join(target_dir, target_name))
@classmethod
def encode_target(cls, target):
return cls.id_to_train_id[np.array(target)]
@classmethod
def decode_target(cls, target):
target[target == 255] = 19
#target = target.astype('uint8') + 1
return cls.train_id_to_color[target]
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
"""
image = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.targets[index])
if self.transform:
image, target = self.transform(image, target)
target = self.encode_target(target)
return image, target
def __len__(self):
return len(self.images)
def _load_json(self, path):
with open(path, 'r') as file:
data = json.load(file)
return data
def _get_target_suffix(self, mode, target_type):
if target_type == 'instance':
return '{}_instanceIds.png'.format(mode)
elif target_type == 'semantic':
return '{}_labelIds.png'.format(mode)
elif target_type == 'color':
return '{}_color.png'.format(mode)
elif target_type == 'polygon':
return '{}_polygons.json'.format(mode)
elif target_type == 'depth':
return '{}_disparity.png'.format(mode)
\ No newline at end of file
This diff is collapsed.
import os
import os.path
import hashlib
import errno
from tqdm import tqdm
def gen_bar_updater(pbar):
def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
progress_bytes = count * block_size
pbar.update(progress_bytes - pbar.n)
return bar_update
def check_integrity(fpath, md5=None):
if md5 is None:
return True
if not os.path.isfile(fpath):
return False
md5o = hashlib.md5()
with open(fpath, 'rb') as f:
# read in 1MB chunks
for chunk in iter(lambda: f.read(1024 * 1024), b''):
md5o.update(chunk)
md5c = md5o.hexdigest()
if md5c != md5:
return False
return True
def makedir_exist_ok(dirpath):
"""
Python2 support for os.makedirs(.., exist_ok=True)
"""
try:
os.makedirs(dirpath)
except OSError as e:
if e.errno == errno.EEXIST:
pass
else:
raise
def download_url(url, root, filename=None, md5=None):
"""Download a file from a url and place it in root.
Args:
url (str): URL to download file from
root (str): Directory to place downloaded file in
filename (str): Name to save the file under. If None, use the basename of the URL
md5 (str): MD5 checksum of the download. If None, do not check
"""
from six.moves import urllib
root = os.path.expanduser(root)
if not filename:
filename = os.path.basename(url)
fpath = os.path.join(root, filename)
makedir_exist_ok(root)
# downloads file
if os.path.isfile(fpath) and check_integrity(fpath, md5):
print('Using downloaded and verified file: ' + fpath)
else:
try:
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
)
except OSError:
if url[:5] == 'https':
url = url.replace('https:', 'http:')
print('Failed download. Trying https -> http instead.'
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm(unit='B', unit_scale=True))
)
def list_dir(root, prefix=False):
"""List all directories at a given root
Args:
root (str): Path to directory whose folders need to be listed
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the directories found
"""
root = os.path.expanduser(root)
directories = list(
filter(
lambda p: os.path.isdir(os.path.join(root, p)),
os.listdir(root)
)
)
if prefix is True:
directories = [os.path.join(root, d) for d in directories]
return directories
def list_files(root, suffix, prefix=False):
"""List all files ending with a suffix at a given root
Args:
root (str): Path to directory whose folders need to be listed
suffix (str or tuple): Suffix of the files to match, e.g. '.png' or ('.jpg', '.png').
It uses the Python "str.endswith" method and is passed directly
prefix (bool, optional): If true, prepends the path to each result, otherwise
only returns the name of the files found
"""
root = os.path.expanduser(root)
files = list(
filter(
lambda p: os.path.isfile(os.path.join(root, p)) and p.endswith(suffix),
os.listdir(root)
)
)
if prefix is True:
files = [os.path.join(root, d) for d in files]
return files
\ No newline at end of file
import os
import sys
import tarfile
import collections
import torch.utils.data as data
import shutil
import numpy as np
from PIL import Image
from torchvision.datasets.utils import download_url, check_integrity
DATASET_YEAR_DICT = {
'2012': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2012/VOCtrainval_11-May-2012.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '6cd6e144f989b92b3379bac3b3de84fd',
'base_dir': 'VOCdevkit/VOC2012'
},
'2011': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2011/VOCtrainval_25-May-2011.tar',
'filename': 'VOCtrainval_25-May-2011.tar',
'md5': '6c3384ef61512963050cb5d687e5bf1e',
'base_dir': 'TrainVal/VOCdevkit/VOC2011'
},
'2010': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2010/VOCtrainval_03-May-2010.tar',
'filename': 'VOCtrainval_03-May-2010.tar',
'md5': 'da459979d0c395079b5c75ee67908abb',
'base_dir': 'VOCdevkit/VOC2010'
},
'2009': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2009/VOCtrainval_11-May-2009.tar',
'filename': 'VOCtrainval_11-May-2009.tar',
'md5': '59065e4b188729180974ef6572f6a212',
'base_dir': 'VOCdevkit/VOC2009'
},
'2008': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2008/VOCtrainval_14-Jul-2008.tar',
'filename': 'VOCtrainval_11-May-2012.tar',
'md5': '2629fa636546599198acfcfbfcf1904a',
'base_dir': 'VOCdevkit/VOC2008'
},
'2007': {
'url': 'http://host.robots.ox.ac.uk/pascal/VOC/voc2007/VOCtrainval_06-Nov-2007.tar',
'filename': 'VOCtrainval_06-Nov-2007.tar',
'md5': 'c52e279531787c972589f7e41ab4ae64',
'base_dir': 'VOCdevkit/VOC2007'
}
}
def voc_cmap(N=256, normalized=False):
def bitget(byteval, idx):
return ((byteval & (1 << idx)) != 0)
dtype = 'float32' if normalized else 'uint8'
cmap = np.zeros((N, 3), dtype=dtype)
for i in range(N):
r = g = b = 0
c = i
for j in range(8):
r = r | (bitget(c, 0) << 7-j)
g = g | (bitget(c, 1) << 7-j)
b = b | (bitget(c, 2) << 7-j)
c = c >> 3
cmap[i] = np.array([r, g, b])
cmap = cmap/255 if normalized else cmap
return cmap
class VOCSegmentation(data.Dataset):
"""`Pascal VOC <http://host.robots.ox.ac.uk/pascal/VOC/>`_ Segmentation Dataset.
Args:
root (string): Root directory of the VOC Dataset.
year (string, optional): The dataset year, supports years 2007 to 2012.
image_set (string, optional): Select the image_set to use, ``train``, ``trainval`` or ``val``
download (bool, optional): If true, downloads the dataset from the internet and
puts it in root directory. If dataset is already downloaded, it is not
downloaded again.
transform (callable, optional): A function/transform that takes in an PIL image
and returns a transformed version. E.g, ``transforms.RandomCrop``
"""
cmap = voc_cmap()
def __init__(self,
root,
year='2012',
image_set='train',
download=False,
transform=None):
is_aug=False
if year=='2012_aug':
is_aug = True
year = '2012'
self.root = os.path.expanduser(root)
self.year = year
self.url = DATASET_YEAR_DICT[year]['url']
self.filename = DATASET_YEAR_DICT[year]['filename']
self.md5 = DATASET_YEAR_DICT[year]['md5']
self.transform = transform
self.image_set = image_set
base_dir = DATASET_YEAR_DICT[year]['base_dir']
voc_root = os.path.join(self.root, base_dir)
image_dir = os.path.join(voc_root, 'JPEGImages')
if download:
download_extract(self.url, self.root, self.filename, self.md5)
if not os.path.isdir(voc_root):
raise RuntimeError('Dataset not found or corrupted.' +
' You can use download=True to download it')
if is_aug and image_set=='train':
mask_dir = os.path.join(voc_root, 'SegmentationClassAug')
assert os.path.exists(mask_dir), "SegmentationClassAug not found, please refer to README.md and prepare it manually"
split_f = os.path.join( self.root, 'train_aug.txt')#'./datasets/data/train_aug.txt'
else:
mask_dir = os.path.join(voc_root, 'SegmentationClass')
splits_dir = os.path.join(voc_root, 'ImageSets/Segmentation')
split_f = os.path.join(splits_dir, image_set.rstrip('\n') + '.txt')
if not os.path.exists(split_f):
raise ValueError(
'Wrong image_set entered! Please use image_set="train" '
'or image_set="trainval" or image_set="val"')
with open(os.path.join(split_f), "r") as f:
file_names = [x.strip() for x in f.readlines()]
self.images = [os.path.join(image_dir, x + ".jpg") for x in file_names]
self.masks = [os.path.join(mask_dir, x + ".png") for x in file_names]
assert (len(self.images) == len(self.masks))
def __getitem__(self, index):
"""
Args:
index (int): Index
Returns:
tuple: (image, target) where target is the image segmentation.
"""
img = Image.open(self.images[index]).convert('RGB')
target = Image.open(self.masks[index])
if self.transform is not None:
img, target = self.transform(img, target)
return img, target
def __len__(self):
return len(self.images)
@classmethod
def decode_target(cls, mask):
"""decode semantic mask to RGB image"""
return cls.cmap[mask]
def download_extract(url, root, filename, md5):
download_url(url, root, filename, md5)
with tarfile.open(os.path.join(root, filename), "r") as tar:
tar.extractall(path=root)
\ No newline at end of file
image.png

418 KB

from tqdm import tqdm
import network
import utils
import os
import random
import argparse
import numpy as np
from torch.utils import data
from datasets import VOCSegmentation, Cityscapes
from utils import ext_transforms as et
from metrics import StreamSegMetrics
import torch
import torch.nn as nn
from utils.visualizer import Visualizer
from PIL import Image
import matplotlib
import matplotlib.pyplot as plt
def get_argparser():
parser = argparse.ArgumentParser()
# Datset Options
parser.add_argument("--data_root", type=str, default='./datasets/data',
help="path to Dataset")
parser.add_argument("--dataset", type=str, default='voc',
choices=['voc', 'cityscapes'], help='Name of dataset')
parser.add_argument("--num_classes", type=int, default=None,
help="num classes (default: None)")
# Deeplab Options
available_models = sorted(name for name in network.modeling.__dict__ if name.islower() and \
not (name.startswith("__") or name.startswith('_')) and callable(
network.modeling.__dict__[name])
)
parser.add_argument("--model", type=str, default='deeplabv3plus_mobilenet',
choices=available_models, help='model name')
parser.add_argument("--separable_conv", action='store_true', default=False,
help="apply separable conv to decoder and aspp")
parser.add_argument("--output_stride", type=int, default=16, choices=[8, 16])
# Train Options
parser.add_argument("--test_only", action='store_true', default=False)
parser.add_argument("--save_val_results", action='store_true', default=False,
help="save segmentation results to \"./results\"")
parser.add_argument("--total_itrs", type=int, default=30e3,
help="epoch number (default: 30k)")
parser.add_argument("--lr", type=float, default=0.01,
help="learning rate (default: 0.01)")
parser.add_argument("--lr_policy", type=str, default='poly', choices=['poly', 'step'],
help="learning rate scheduler policy")
parser.add_argument("--step_size", type=int, default=10000)
parser.add_argument("--crop_val", action='store_true', default=False,
help='crop validation (default: False)')
parser.add_argument("--batch_size", type=int, default=16,
help='batch size (default: 16)')
parser.add_argument("--val_batch_size", type=int, default=4,
help='batch size for validation (default: 4)')
parser.add_argument("--crop_size", type=int, default=513)
parser.add_argument("--ckpt", default=None, type=str,
help="restore from checkpoint")
parser.add_argument("--continue_training", action='store_true', default=False)
parser.add_argument("--loss_type", type=str, default='cross_entropy',
choices=['cross_entropy', 'focal_loss'], help="loss type (default: False)")
parser.add_argument("--gpu_id", type=str, default='0',
help="GPU ID")
parser.add_argument("--weight_decay", type=float, default=1e-4,
help='weight decay (default: 1e-4)')
parser.add_argument("--random_seed", type=int, default=1,
help="random seed (default: 1)")
parser.add_argument("--print_interval", type=int, default=10,
help="print interval of loss (default: 10)")
parser.add_argument("--val_interval", type=int, default=100,
help="epoch interval for eval (default: 100)")
parser.add_argument("--download", action='store_true', default=False,
help="download datasets")
# PASCAL VOC Options
parser.add_argument("--year", type=str, default='2012',
choices=['2012_aug', '2012', '2011', '2009', '2008', '2007'], help='year of VOC')
# Visdom options
parser.add_argument("--enable_vis", action='store_true', default=False,
help="use visdom for visualization")
parser.add_argument("--vis_port", type=str, default='13570',
help='port for visdom')
parser.add_argument("--vis_env", type=str, default='main',
help='env for visdom')
parser.add_argument("--vis_num_samples", type=int, default=8,
help='number of samples for visualization (default: 8)')
return parser
def get_dataset(opts):
""" Dataset And Augmentation
"""
if opts.dataset == 'voc':
train_transform = et.ExtCompose([
# et.ExtResize(size=opts.crop_size),
et.ExtRandomScale((0.5, 2.0)),
et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size), pad_if_needed=True),
et.ExtRandomHorizontalFlip(),
et.ExtToTensor(),
et.ExtNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
if opts.crop_val:
val_transform = et.ExtCompose([
et.ExtResize(opts.crop_size),
et.ExtCenterCrop(opts.crop_size),
et.ExtToTensor(),
et.ExtNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
else:
val_transform = et.ExtCompose([
et.ExtToTensor(),
et.ExtNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
image_set='train', download=opts.download, transform=train_transform)
val_dst = VOCSegmentation(root=opts.data_root, year=opts.year,
image_set='val', download=False, transform=val_transform)
if opts.dataset == 'cityscapes':
train_transform = et.ExtCompose([
# et.ExtResize( 512 ),
et.ExtRandomCrop(size=(opts.crop_size, opts.crop_size)),
et.ExtColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
et.ExtRandomHorizontalFlip(),
et.ExtToTensor(),
et.ExtNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
val_transform = et.ExtCompose([
# et.ExtResize( 512 ),
et.ExtToTensor(),
et.ExtNormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]),
])
train_dst = Cityscapes(root=opts.data_root,
split='train', transform=train_transform)
val_dst = Cityscapes(root=opts.data_root,
split='val', transform=val_transform)
return train_dst, val_dst
def validate(opts, model, loader, device, metrics, ret_samples_ids=None):
"""Do validation and return specified samples"""
metrics.reset()
ret_samples = []
if opts.save_val_results:
if not os.path.exists('results'):
os.mkdir('results')
denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
img_id = 0
with torch.no_grad():
for i, (images, labels) in tqdm(enumerate(loader)):
images = images.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.long)
outputs = model(images)
preds = outputs.detach().max(dim=1)[1].cpu().numpy()
targets = labels.cpu().numpy()
metrics.update(targets, preds)
if ret_samples_ids is not None and i in ret_samples_ids: # get vis samples
ret_samples.append(
(images[0].detach().cpu().numpy(), targets[0], preds[0]))
if opts.save_val_results:
for i in range(len(images)):
image = images[i].detach().cpu().numpy()
target = targets[i]
pred = preds[i]
image = (denorm(image) * 255).transpose(1, 2, 0).astype(np.uint8)
target = loader.dataset.decode_target(target).astype(np.uint8)
pred = loader.dataset.decode_target(pred).astype(np.uint8)
Image.fromarray(image).save('results/%d_image.png' % img_id)
Image.fromarray(target).save('results/%d_target.png' % img_id)
Image.fromarray(pred).save('results/%d_pred.png' % img_id)
fig = plt.figure()
plt.imshow(image)
plt.axis('off')
plt.imshow(pred, alpha=0.7)
ax = plt.gca()
ax.xaxis.set_major_locator(matplotlib.ticker.NullLocator())
ax.yaxis.set_major_locator(matplotlib.ticker.NullLocator())
plt.savefig('results/%d_overlay.png' % img_id, bbox_inches='tight', pad_inches=0)
plt.close()
img_id += 1
score = metrics.get_results()
return score, ret_samples
def main():
opts = get_argparser().parse_args()
if opts.dataset.lower() == 'voc':
opts.num_classes = 21
elif opts.dataset.lower() == 'cityscapes':
opts.num_classes = 19
# Setup visualization
vis = Visualizer(port=opts.vis_port,
env=opts.vis_env) if opts.enable_vis else None
if vis is not None: # display options
vis.vis_table("Options", vars(opts))
os.environ['CUDA_VISIBLE_DEVICES'] = opts.gpu_id
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print("Device: %s" % device)
# Setup random seed
torch.manual_seed(opts.random_seed)
np.random.seed(opts.random_seed)
random.seed(opts.random_seed)
# Setup dataloader
if opts.dataset == 'voc' and not opts.crop_val:
opts.val_batch_size = 1
train_dst, val_dst = get_dataset(opts)
train_loader = data.DataLoader(
train_dst, batch_size=opts.batch_size, shuffle=True, num_workers=2,
drop_last=True) # drop_last=True to ignore single-image batches.
val_loader = data.DataLoader(
val_dst, batch_size=opts.val_batch_size, shuffle=True, num_workers=2)
print("Dataset: %s, Train set: %d, Val set: %d" %
(opts.dataset, len(train_dst), len(val_dst)))
# Set up model (all models are 'constructed at network.modeling)
model = network.modeling.__dict__[opts.model](num_classes=opts.num_classes, output_stride=opts.output_stride)
if opts.separable_conv and 'plus' in opts.model:
network.convert_to_separable_conv(model.classifier)
utils.set_bn_momentum(model.backbone, momentum=0.01)
# Set up metrics
metrics = StreamSegMetrics(opts.num_classes)
# Set up optimizer
optimizer = torch.optim.SGD(params=[
{'params': model.backbone.parameters(), 'lr': 0.1 * opts.lr},
{'params': model.classifier.parameters(), 'lr': opts.lr},
], lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# optimizer = torch.optim.SGD(params=model.parameters(), lr=opts.lr, momentum=0.9, weight_decay=opts.weight_decay)
# torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.lr_decay_step, gamma=opts.lr_decay_factor)
if opts.lr_policy == 'poly':
scheduler = utils.PolyLR(optimizer, opts.total_itrs, power=0.9)
elif opts.lr_policy == 'step':
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=opts.step_size, gamma=0.1)
# Set up criterion
# criterion = utils.get_loss(opts.loss_type)
if opts.loss_type == 'focal_loss':
criterion = utils.FocalLoss(ignore_index=255, size_average=True)
elif opts.loss_type == 'cross_entropy':
criterion = nn.CrossEntropyLoss(ignore_index=255, reduction='mean')
def save_ckpt(path):
""" save current model
"""
torch.save({
"cur_itrs": cur_itrs,
"model_state": model.module.state_dict(),
"optimizer_state": optimizer.state_dict(),
"scheduler_state": scheduler.state_dict(),
"best_score": best_score,
}, path)
print("Model saved as %s" % path)
utils.mkdir('checkpoints')
# Restore
best_score = 0.0
cur_itrs = 0
cur_epochs = 0
if opts.ckpt is not None and os.path.isfile(opts.ckpt):
# https://github.com/VainF/DeepLabV3Plus-Pytorch/issues/8#issuecomment-605601402, @PytaichukBohdan
checkpoint = torch.load(opts.ckpt, map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state"])
model = nn.DataParallel(model)
model.to(device)
if opts.continue_training:
optimizer.load_state_dict(checkpoint["optimizer_state"])
scheduler.load_state_dict(checkpoint["scheduler_state"])
cur_itrs = checkpoint["cur_itrs"]
best_score = checkpoint['best_score']
print("Training state restored from %s" % opts.ckpt)
print("Model restored from %s" % opts.ckpt)
del checkpoint # free memory
else:
print("[!] Retrain")
model = nn.DataParallel(model)
model.to(device)
# ========== Train Loop ==========#
vis_sample_id = np.random.randint(0, len(val_loader), opts.vis_num_samples,
np.int32) if opts.enable_vis else None # sample idxs for visualization
denorm = utils.Denormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # denormalization for ori images
if opts.test_only:
model.eval()
val_score, ret_samples = validate(
opts=opts, model=model, loader=val_loader, device=device, metrics=metrics, ret_samples_ids=vis_sample_id)
print(metrics.to_str(val_score))
return
interval_loss = 0
while True: # cur_itrs < opts.total_itrs:
# ===== Train =====
model.train()
cur_epochs += 1
for (images, labels) in train_loader:
cur_itrs += 1
images = images.to(device, dtype=torch.float32)
labels = labels.to(device, dtype=torch.long)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
np_loss = loss.detach().cpu().numpy()
interval_loss += np_loss
if vis is not None:
vis.vis_scalar('Loss', cur_itrs, np_loss)
if (cur_itrs) % 10 == 0:
interval_loss = interval_loss / 10
print("Epoch %d, Itrs %d/%d, Loss=%f" %
(cur_epochs, cur_itrs, opts.total_itrs, interval_loss))
interval_loss = 0.0
if (cur_itrs) % opts.val_interval == 0:
save_ckpt('checkpoints/latest_%s_%s_os%d.pth' %
(opts.model, opts.dataset, opts.output_stride))
print("validation...")
model.eval()
val_score, ret_samples = validate(
opts=opts, model=model, loader=val_loader, device=device, metrics=metrics,
ret_samples_ids=vis_sample_id)
print(metrics.to_str(val_score))
if val_score['Mean IoU'] > best_score: # save best model
best_score = val_score['Mean IoU']
save_ckpt('checkpoints/best_%s_%s_os%d.pth' %
(opts.model, opts.dataset, opts.output_stride))
if vis is not None: # visualize validation score and samples
vis.vis_scalar("[Val] Overall Acc", cur_itrs, val_score['Overall Acc'])
vis.vis_scalar("[Val] Mean IoU", cur_itrs, val_score['Mean IoU'])
vis.vis_table("[Val] Class IoU", val_score['Class IoU'])
for k, (img, target, lbl) in enumerate(ret_samples):
img = (denorm(img) * 255).astype(np.uint8)
target = train_dst.decode_target(target).transpose(2, 0, 1).astype(np.uint8)
lbl = train_dst.decode_target(lbl).transpose(2, 0, 1).astype(np.uint8)
concat_img = np.concatenate((img, target, lbl), axis=2) # concat along width
vis.vis_image('Sample %d' % k, concat_img)
model.train()
scheduler.step()
if cur_itrs >= opts.total_itrs:
return
if __name__ == '__main__':
main()
from .stream_metrics import StreamSegMetrics, AverageMeter
import numpy as np
from sklearn.metrics import confusion_matrix
class _StreamMetrics(object):
def __init__(self):
""" Overridden by subclasses """
raise NotImplementedError()
def update(self, gt, pred):
""" Overridden by subclasses """
raise NotImplementedError()
def get_results(self):
""" Overridden by subclasses """
raise NotImplementedError()
def to_str(self, metrics):
""" Overridden by subclasses """
raise NotImplementedError()
def reset(self):
""" Overridden by subclasses """
raise NotImplementedError()
class StreamSegMetrics(_StreamMetrics):
"""
Stream Metrics for Semantic Segmentation Task
"""
def __init__(self, n_classes):
self.n_classes = n_classes
self.confusion_matrix = np.zeros((n_classes, n_classes))
def update(self, label_trues, label_preds):
for lt, lp in zip(label_trues, label_preds):
self.confusion_matrix += self._fast_hist( lt.flatten(), lp.flatten() )
@staticmethod
def to_str(results):
string = "\n"
for k, v in results.items():
if k!="Class IoU":
string += "%s: %f\n"%(k, v)
#string+='Class IoU:\n'
#for k, v in results['Class IoU'].items():
# string += "\tclass %d: %f\n"%(k, v)
return string
def _fast_hist(self, label_true, label_pred):
mask = (label_true >= 0) & (label_true < self.n_classes)
hist = np.bincount(
self.n_classes * label_true[mask].astype(int) + label_pred[mask],
minlength=self.n_classes ** 2,
).reshape(self.n_classes, self.n_classes)
return hist
def get_results(self):
"""Returns accuracy score evaluation result.
- overall accuracy
- mean accuracy
- mean IU
- fwavacc
"""
hist = self.confusion_matrix
acc = np.diag(hist).sum() / hist.sum()
acc_cls = np.diag(hist) / hist.sum(axis=1)
acc_cls = np.nanmean(acc_cls)
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
mean_iu = np.nanmean(iu)
freq = hist.sum(axis=1) / hist.sum()
fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
cls_iu = dict(zip(range(self.n_classes), iu))
return {
"Overall Acc": acc,
"Mean Acc": acc_cls,
"FreqW Acc": fwavacc,
"Mean IoU": mean_iu,
"Class IoU": cls_iu,
}
def reset(self):
self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))
class AverageMeter(object):
"""Computes average values"""
def __init__(self):
self.book = dict()
def reset_all(self):
self.book.clear()
def reset(self, id):
item = self.book.get(id, None)
if item is not None:
item[0] = 0
item[1] = 0
def update(self, id, val):
record = self.book.get(id, None)
if record is None:
self.book[id] = [val, 1]
else:
record[0]+=val
record[1]+=1
def get_results(self, id):
record = self.book.get(id, None)
assert record is not None
return record[0] / record[1]
# 模型唯一标识
modelCode=1821
# 模型名称
modelName=DeepLabV3-Plus
# 模型描述
modelDescription=DeepLabv3Plus是2025年医学图像分割领域最火的模型之一,结合了UNet的编码器-解码器结构和DeepLabv3的ASPP模块,适用于像素级图像分割任务
# 应用场景
processType=推理,训练
# 算法类别
appScenario=图像分割
# 框架类型
frameType=pytorch
# 加速卡类型
accelerateType=K100AI
\ No newline at end of file
from .modeling import *
from ._deeplab import convert_to_separable_conv
\ No newline at end of file
import torch
from torch import nn
from torch.nn import functional as F
from .utils import _SimpleSegmentationModel
__all__ = ["DeepLabV3"]
class DeepLabV3(_SimpleSegmentationModel):
"""
Implements DeepLabV3 model from
`"Rethinking Atrous Convolution for Semantic Image Segmentation"
<https://arxiv.org/abs/1706.05587>`_.
Arguments:
backbone (nn.Module): the network used to compute the features for the model.
The backbone should return an OrderedDict[Tensor], with the key being
"out" for the last feature map used, and "aux" if an auxiliary classifier
is used.
classifier (nn.Module): module that takes the "out" element returned from
the backbone and returns a dense prediction.
aux_classifier (nn.Module, optional): auxiliary classifier used during training
"""
pass
class DeepLabHeadV3Plus(nn.Module):
def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
super(DeepLabHeadV3Plus, self).__init__()
self.project = nn.Sequential(
nn.Conv2d(low_level_channels, 48, 1, bias=False),
nn.BatchNorm2d(48),
nn.ReLU(inplace=True),
)
self.aspp = ASPP(in_channels, aspp_dilate)
self.classifier = nn.Sequential(
nn.Conv2d(304, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, 1)
)
self._init_weight()
def forward(self, feature):
low_level_feature = self.project( feature['low_level'] )
output_feature = self.aspp(feature['out'])
output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class DeepLabHead(nn.Module):
def __init__(self, in_channels, num_classes, aspp_dilate=[12, 24, 36]):
super(DeepLabHead, self).__init__()
self.classifier = nn.Sequential(
ASPP(in_channels, aspp_dilate),
nn.Conv2d(256, 256, 3, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, num_classes, 1)
)
self._init_weight()
def forward(self, feature):
return self.classifier( feature['out'] )
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class AtrousSeparableConvolution(nn.Module):
""" Atrous Separable Convolution
"""
def __init__(self, in_channels, out_channels, kernel_size,
stride=1, padding=0, dilation=1, bias=True):
super(AtrousSeparableConvolution, self).__init__()
self.body = nn.Sequential(
# Separable Conv
nn.Conv2d( in_channels, in_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, bias=bias, groups=in_channels ),
# PointWise Conv
nn.Conv2d( in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=bias),
)
self._init_weight()
def forward(self, x):
return self.body(x)
def _init_weight(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight)
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
class ASPPConv(nn.Sequential):
def __init__(self, in_channels, out_channels, dilation):
modules = [
nn.Conv2d(in_channels, out_channels, 3, padding=dilation, dilation=dilation, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
]
super(ASPPConv, self).__init__(*modules)
class ASPPPooling(nn.Sequential):
def __init__(self, in_channels, out_channels):
super(ASPPPooling, self).__init__(
nn.AdaptiveAvgPool2d(1),
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True))
def forward(self, x):
size = x.shape[-2:]
x = super(ASPPPooling, self).forward(x)
return F.interpolate(x, size=size, mode='bilinear', align_corners=False)
class ASPP(nn.Module):
def __init__(self, in_channels, atrous_rates):
super(ASPP, self).__init__()
out_channels = 256
modules = []
modules.append(nn.Sequential(
nn.Conv2d(in_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)))
rate1, rate2, rate3 = tuple(atrous_rates)
modules.append(ASPPConv(in_channels, out_channels, rate1))
modules.append(ASPPConv(in_channels, out_channels, rate2))
modules.append(ASPPConv(in_channels, out_channels, rate3))
modules.append(ASPPPooling(in_channels, out_channels))
self.convs = nn.ModuleList(modules)
self.project = nn.Sequential(
nn.Conv2d(5 * out_channels, out_channels, 1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True),
nn.Dropout(0.1),)
def forward(self, x):
res = []
for conv in self.convs:
res.append(conv(x))
res = torch.cat(res, dim=1)
return self.project(res)
def convert_to_separable_conv(module):
new_module = module
if isinstance(module, nn.Conv2d) and module.kernel_size[0]>1:
new_module = AtrousSeparableConvolution(module.in_channels,
module.out_channels,
module.kernel_size,
module.stride,
module.padding,
module.dilation,
module.bias)
for name, child in module.named_children():
new_module.add_module(name, convert_to_separable_conv(child))
return new_module
\ No newline at end of file
from . import resnet
from . import mobilenetv2
from . import hrnetv2
from . import xception
import torch
from torch import nn
import torch.nn.functional as F
import os
__all__ = ['HRNet', 'hrnetv2_48', 'hrnetv2_32']
# Checkpoint path of pre-trained backbone (edit to your path). Download backbone pretrained model hrnetv2-32 @
# https://drive.google.com/file/d/1NxCK7Zgn5PmeS7W1jYLt5J9E0RRZ2oyF/view?usp=sharing .Personally, I added the backbone
# weights to the folder /checkpoints
model_urls = {
'hrnetv2_32': './checkpoints/model_best_epoch96_edit.pth',
'hrnetv2_48': None
}
def check_pth(arch):
CKPT_PATH = model_urls[arch]
if os.path.exists(CKPT_PATH):
print(f"Backbone HRNet Pretrained weights at: {CKPT_PATH}, only usable for HRNetv2-32")
else:
print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
return CKPT_PATH
# HRNetv2-48 not available yet, but you can train the whole model from scratch.
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.downsample = downsample
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class StageModule(nn.Module):
def __init__(self, stage, output_branches, c):
super(StageModule, self).__init__()
self.number_of_branches = stage # number of branches is equivalent to the stage configuration.
self.output_branches = output_branches
self.branches = nn.ModuleList()
# Note: Resolution + Number of channels maintains the same throughout respective branch.
for i in range(self.number_of_branches): # Stage scales with the number of branches. Ex: Stage 2 -> 2 branch
channels = c * (2 ** i) # Scale channels by 2x for branch with lower resolution,
# Paper does x4 basic block for each forward sequence in each branch (x4 basic block considered as a block)
branch = nn.Sequential(*[BasicBlock(channels, channels) for _ in range(4)])
self.branches.append(branch) # list containing all forward sequence of individual branches.
# For each branch requires repeated fusion with all other branches after passing through x4 basic blocks.
self.fuse_layers = nn.ModuleList()
for branch_output_number in range(self.output_branches):
self.fuse_layers.append(nn.ModuleList())
for branch_number in range(self.number_of_branches):
if branch_number == branch_output_number:
self.fuse_layers[-1].append(nn.Sequential()) # Used in place of "None" because it is callable
elif branch_number > branch_output_number:
self.fuse_layers[-1].append(nn.Sequential(
nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_output_number), kernel_size=1, stride=1,
bias=False),
nn.BatchNorm2d(c * (2 ** branch_output_number), eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True),
nn.Upsample(scale_factor=(2.0 ** (branch_number - branch_output_number)), mode='nearest'),
))
elif branch_number < branch_output_number:
downsampling_fusion = []
for _ in range(branch_output_number - branch_number - 1):
downsampling_fusion.append(nn.Sequential(
nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_number), kernel_size=3, stride=2,
padding=1,
bias=False),
nn.BatchNorm2d(c * (2 ** branch_number), eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True),
nn.ReLU(inplace=True),
))
downsampling_fusion.append(nn.Sequential(
nn.Conv2d(c * (2 ** branch_number), c * (2 ** branch_output_number), kernel_size=3,
stride=2, padding=1,
bias=False),
nn.BatchNorm2d(c * (2 ** branch_output_number), eps=1e-05, momentum=0.1, affine=True,
track_running_stats=True),
))
self.fuse_layers[-1].append(nn.Sequential(*downsampling_fusion))
self.relu = nn.ReLU(inplace=True)
def forward(self, x):
# input to each stage is a list of inputs for each branch
x = [branch(branch_input) for branch, branch_input in zip(self.branches, x)]
x_fused = []
for branch_output_index in range(
self.output_branches): # Amount of output branches == total length of fusion layers
for input_index in range(self.number_of_branches): # The inputs of other branches to be fused.
if input_index == 0:
x_fused.append(self.fuse_layers[branch_output_index][input_index](x[input_index]))
else:
x_fused[branch_output_index] = x_fused[branch_output_index] + self.fuse_layers[branch_output_index][
input_index](x[input_index])
# After fusing all streams together, you will need to pass the fused layers
for i in range(self.output_branches):
x_fused[i] = self.relu(x_fused[i])
return x_fused # returning a list of fused outputs
class HRNet(nn.Module):
def __init__(self, c=48, num_blocks=[1, 4, 3], num_classes=1000):
super(HRNet, self).__init__()
# Stem:
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(64, eps=1e-05, affine=True, track_running_stats=True)
self.relu = nn.ReLU(inplace=True)
# Stage 1:
downsample = nn.Sequential(
nn.Conv2d(64, 256, kernel_size=1, stride=1, bias=False),
nn.BatchNorm2d(256, eps=1e-05, affine=True, track_running_stats=True),
)
# Note that bottleneck module will expand the output channels according to the output channels*block.expansion
bn_expansion = Bottleneck.expansion # The channel expansion is set in the bottleneck class.
self.layer1 = nn.Sequential(
Bottleneck(64, 64, downsample=downsample), # Input is 64 for first module connection
Bottleneck(bn_expansion * 64, 64),
Bottleneck(bn_expansion * 64, 64),
Bottleneck(bn_expansion * 64, 64),
)
# Transition 1 - Creation of the first two branches (one full and one half resolution)
# Need to transition into high resolution stream and mid resolution stream
self.transition1 = nn.ModuleList([
nn.Sequential(
nn.Conv2d(256, c, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(c, eps=1e-05, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
),
nn.Sequential(nn.Sequential( # Double Sequential to fit with official pretrained weights
nn.Conv2d(256, c * 2, kernel_size=3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(c * 2, eps=1e-05, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
)),
])
# Stage 2:
number_blocks_stage2 = num_blocks[0]
self.stage2 = nn.Sequential(
*[StageModule(stage=2, output_branches=2, c=c) for _ in range(number_blocks_stage2)])
# Transition 2 - Creation of the third branch (1/4 resolution)
self.transition2 = self._make_transition_layers(c, transition_number=2)
# Stage 3:
number_blocks_stage3 = num_blocks[1] # number blocks you want to create before fusion
self.stage3 = nn.Sequential(
*[StageModule(stage=3, output_branches=3, c=c) for _ in range(number_blocks_stage3)])
# Transition - Creation of the fourth branch (1/8 resolution)
self.transition3 = self._make_transition_layers(c, transition_number=3)
# Stage 4:
number_blocks_stage4 = num_blocks[2] # number blocks you want to create before fusion
self.stage4 = nn.Sequential(
*[StageModule(stage=4, output_branches=4, c=c) for _ in range(number_blocks_stage4)])
# Classifier (extra module if want to use for classification):
# pool, reduce dimensionality, flatten, connect to linear layer for classification:
out_channels = sum([c * 2 ** i for i in range(len(num_blocks)+1)]) # total output channels of HRNetV2
pool_feature_map = 8
self.bn_classifier = nn.Sequential(
nn.Conv2d(out_channels, out_channels // 4, kernel_size=1, bias=False),
nn.BatchNorm2d(out_channels // 4, eps=1e-05, affine=True, track_running_stats=True),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(pool_feature_map),
nn.Flatten(),
nn.Linear(pool_feature_map * pool_feature_map * (out_channels // 4), num_classes),
)
@staticmethod
def _make_transition_layers(c, transition_number):
return nn.Sequential(
nn.Conv2d(c * (2 ** (transition_number - 1)), c * (2 ** transition_number), kernel_size=3, stride=2,
padding=1, bias=False),
nn.BatchNorm2d(c * (2 ** transition_number), eps=1e-05, affine=True,
track_running_stats=True),
nn.ReLU(inplace=True),
)
def forward(self, x):
# Stem:
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
# Stage 1
x = self.layer1(x)
x = [trans(x) for trans in self.transition1] # split to 2 branches, form a list.
# Stage 2
x = self.stage2(x)
x.append(self.transition2(x[-1]))
# Stage 3
x = self.stage3(x)
x.append(self.transition3(x[-1]))
# Stage 4
x = self.stage4(x)
# HRNetV2 Example: (follow paper, upsample via bilinear interpolation and to highest resolution size)
output_h, output_w = x[0].size(2), x[0].size(3) # Upsample to size of highest resolution stream
x1 = F.interpolate(x[1], size=(output_h, output_w), mode='bilinear', align_corners=False)
x2 = F.interpolate(x[2], size=(output_h, output_w), mode='bilinear', align_corners=False)
x3 = F.interpolate(x[3], size=(output_h, output_w), mode='bilinear', align_corners=False)
# Upsampling all the other resolution streams and then concatenate all (rather than adding/fusing like HRNetV1)
x = torch.cat([x[0], x1, x2, x3], dim=1)
x = self.bn_classifier(x)
return x
def _hrnet(arch, channels, num_blocks, pretrained, progress, **kwargs):
model = HRNet(channels, num_blocks, **kwargs)
if pretrained:
CKPT_PATH = check_pth(arch)
checkpoint = torch.load(CKPT_PATH)
model.load_state_dict(checkpoint['state_dict'])
return model
def hrnetv2_48(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
w_channels = 48
return _hrnet('hrnetv2_48', w_channels, number_blocks, pretrained, progress,
**kwargs)
def hrnetv2_32(pretrained=False, progress=True, number_blocks=[1, 4, 3], **kwargs):
w_channels = 32
return _hrnet('hrnetv2_32', w_channels, number_blocks, pretrained, progress,
**kwargs)
if __name__ == '__main__':
try:
CKPT_PATH = os.path.join(os.path.abspath("."), '../../checkpoints/hrnetv2_32_model_best_epoch96.pth')
print("--- Running file as MAIN ---")
print(f"Backbone HRNET Pretrained weights as __main__ at: {CKPT_PATH}")
except:
print("No backbone checkpoint found for HRNetv2, please set pretrained=False when calling model")
# Models
model = hrnetv2_32(pretrained=True)
#model = hrnetv2_48(pretrained=False)
if torch.cuda.is_available():
torch.backends.cudnn.deterministic = True
device = torch.device('cuda')
else:
device = torch.device('cpu')
model.to(device)
in_ = torch.ones(1, 3, 768, 768).to(device)
y = model(in_)
print(y.shape)
# Calculate total number of parameters:
# pytorch_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
# print(pytorch_total_params)
from torch import nn
try: # for torchvision<0.4
from torchvision.models.utils import load_state_dict_from_url
except: # for torchvision>=0.4
from torch.hub import load_state_dict_from_url
import torch.nn.functional as F
__all__ = ['MobileNetV2', 'mobilenet_v2']
model_urls = {
'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
}
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
class ConvBNReLU(nn.Sequential):
def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
#padding = (kernel_size - 1) // 2
super(ConvBNReLU, self).__init__(
nn.Conv2d(in_planes, out_planes, kernel_size, stride, 0, dilation=dilation, groups=groups, bias=False),
nn.BatchNorm2d(out_planes),
nn.ReLU6(inplace=True)
)
def fixed_padding(kernel_size, dilation):
kernel_size_effective = kernel_size + (kernel_size - 1) * (dilation - 1)
pad_total = kernel_size_effective - 1
pad_beg = pad_total // 2
pad_end = pad_total - pad_beg
return (pad_beg, pad_end, pad_beg, pad_end)
class InvertedResidual(nn.Module):
def __init__(self, inp, oup, stride, dilation, expand_ratio):
super(InvertedResidual, self).__init__()
self.stride = stride
assert stride in [1, 2]
hidden_dim = int(round(inp * expand_ratio))
self.use_res_connect = self.stride == 1 and inp == oup
layers = []
if expand_ratio != 1:
# pw
layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
layers.extend([
# dw
ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
# pw-linear
nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
nn.BatchNorm2d(oup),
])
self.conv = nn.Sequential(*layers)
self.input_padding = fixed_padding( 3, dilation )
def forward(self, x):
x_pad = F.pad(x, self.input_padding)
if self.use_res_connect:
return x + self.conv(x_pad)
else:
return self.conv(x_pad)
class MobileNetV2(nn.Module):
def __init__(self, num_classes=1000, output_stride=8, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
"""
MobileNet V2 main class
Args:
num_classes (int): Number of classes
width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
inverted_residual_setting: Network structure
round_nearest (int): Round the number of channels in each layer to be a multiple of this number
Set to 1 to turn off rounding
"""
super(MobileNetV2, self).__init__()
block = InvertedResidual
input_channel = 32
last_channel = 1280
self.output_stride = output_stride
current_stride = 1
if inverted_residual_setting is None:
inverted_residual_setting = [
# t, c, n, s
[1, 16, 1, 1],
[6, 24, 2, 2],
[6, 32, 3, 2],
[6, 64, 4, 2],
[6, 96, 3, 1],
[6, 160, 3, 2],
[6, 320, 1, 1],
]
# only check the first element, assuming user knows t,c,n,s are required
if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
raise ValueError("inverted_residual_setting should be non-empty "
"or a 4-element list, got {}".format(inverted_residual_setting))
# building first layer
input_channel = _make_divisible(input_channel * width_mult, round_nearest)
self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
features = [ConvBNReLU(3, input_channel, stride=2)]
current_stride *= 2
dilation=1
previous_dilation = 1
# building inverted residual blocks
for t, c, n, s in inverted_residual_setting:
output_channel = _make_divisible(c * width_mult, round_nearest)
previous_dilation = dilation
if current_stride == output_stride:
stride = 1
dilation *= s
else:
stride = s
current_stride *= s
output_channel = int(c * width_mult)
for i in range(n):
if i==0:
features.append(block(input_channel, output_channel, stride, previous_dilation, expand_ratio=t))
else:
features.append(block(input_channel, output_channel, 1, dilation, expand_ratio=t))
input_channel = output_channel
# building last several layers
features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
# make it nn.Sequential
self.features = nn.Sequential(*features)
# building classifier
self.classifier = nn.Sequential(
nn.Dropout(0.2),
nn.Linear(self.last_channel, num_classes),
)
# weight initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out')
if m.bias is not None:
nn.init.zeros_(m.bias)
elif isinstance(m, nn.BatchNorm2d):
nn.init.ones_(m.weight)
nn.init.zeros_(m.bias)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.zeros_(m.bias)
def forward(self, x):
x = self.features(x)
x = x.mean([2, 3])
x = self.classifier(x)
return x
def mobilenet_v2(pretrained=False, progress=True, **kwargs):
"""
Constructs a MobileNetV2 architecture from
`"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
model = MobileNetV2(**kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
progress=progress)
model.load_state_dict(state_dict)
return model
import torch
import torch.nn as nn
try: # for torchvision<0.4
from torchvision.models.utils import load_state_dict_from_url
except: # for torchvision>=0.4
from torch.hub import load_state_dict_from_url
__all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101',
'resnet152', 'resnext50_32x4d', 'resnext101_32x8d',
'wide_resnet50_2', 'wide_resnet101_2']
model_urls = {
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth',
'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth',
'wide_resnet50_2': 'https://download.pytorch.org/models/wide_resnet50_2-95faca4d.pth',
'wide_resnet101_2': 'https://download.pytorch.org/models/wide_resnet101_2-32ee1156.pth',
}
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes, out_planes, stride=1):
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(BasicBlock, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
base_width=64, dilation=1, norm_layer=None):
super(Bottleneck, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
class ResNet(nn.Module):
def __init__(self, block, layers, num_classes=1000, zero_init_residual=False,
groups=1, width_per_group=64, replace_stride_with_dilation=None,
norm_layer=None):
super(ResNet, self).__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
self._norm_layer = norm_layer
self.inplanes = 64
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False]
if len(replace_stride_with_dilation) != 3:
raise ValueError("replace_stride_with_dilation should be None "
"or a 3-element tuple, got {}".format(replace_stride_with_dilation))
self.groups = groups
self.base_width = width_per_group
self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
bias=False)
self.bn1 = norm_layer(self.inplanes)
self.relu = nn.ReLU(inplace=True)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(block, 64, layers[0])
self.layer2 = self._make_layer(block, 128, layers[1], stride=2,
dilate=replace_stride_with_dilation[0])
self.layer3 = self._make_layer(block, 256, layers[2], stride=2,
dilate=replace_stride_with_dilation[1])
self.layer4 = self._make_layer(block, 512, layers[3], stride=2,
dilate=replace_stride_with_dilation[2])
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
self.fc = nn.Linear(512 * block.expansion, num_classes)
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
# Zero-initialize the last BN in each residual branch,
# so that the residual branch starts with zeros, and each residual block behaves like an identity.
# This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
if zero_init_residual:
for m in self.modules():
if isinstance(m, Bottleneck):
nn.init.constant_(m.bn3.weight, 0)
elif isinstance(m, BasicBlock):
nn.init.constant_(m.bn2.weight, 0)
def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
norm_layer = self._norm_layer
downsample = None
previous_dilation = self.dilation
if dilate:
self.dilation *= stride
stride = 1
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, planes * block.expansion, stride),
norm_layer(planes * block.expansion),
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool(x)
x = self.layer1(x)
x = self.layer2(x)
x = self.layer3(x)
x = self.layer4(x)
x = self.avgpool(x)
x = torch.flatten(x, 1)
x = self.fc(x)
return x
def _resnet(arch, block, layers, pretrained, progress, **kwargs):
model = ResNet(block, layers, **kwargs)
if pretrained:
state_dict = load_state_dict_from_url(model_urls[arch],
progress=progress)
model.load_state_dict(state_dict)
return model
def resnet18(pretrained=False, progress=True, **kwargs):
r"""ResNet-18 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet18', BasicBlock, [2, 2, 2, 2], pretrained, progress,
**kwargs)
def resnet34(pretrained=False, progress=True, **kwargs):
r"""ResNet-34 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet34', BasicBlock, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet50(pretrained=False, progress=True, **kwargs):
r"""ResNet-50 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet50', Bottleneck, [3, 4, 6, 3], pretrained, progress,
**kwargs)
def resnet101(pretrained=False, progress=True, **kwargs):
r"""ResNet-101 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet101', Bottleneck, [3, 4, 23, 3], pretrained, progress,
**kwargs)
def resnet152(pretrained=False, progress=True, **kwargs):
r"""ResNet-152 model from
`"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
return _resnet('resnet152', Bottleneck, [3, 8, 36, 3], pretrained, progress,
**kwargs)
def resnext50_32x4d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-50 32x4d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 4
return _resnet('resnext50_32x4d', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def resnext101_32x8d(pretrained=False, progress=True, **kwargs):
r"""ResNeXt-101 32x8d model from
`"Aggregated Residual Transformation for Deep Neural Networks" <https://arxiv.org/pdf/1611.05431.pdf>`_
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['groups'] = 32
kwargs['width_per_group'] = 8
return _resnet('resnext101_32x8d', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
def wide_resnet50_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-50-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet50_2', Bottleneck, [3, 4, 6, 3],
pretrained, progress, **kwargs)
def wide_resnet101_2(pretrained=False, progress=True, **kwargs):
r"""Wide ResNet-101-2 model from
`"Wide Residual Networks" <https://arxiv.org/pdf/1605.07146.pdf>`_
The model is the same as ResNet except for the bottleneck number of channels
which is twice larger in every block. The number of channels in outer 1x1
convolutions is the same, e.g. last block in ResNet-50 has 2048-512-2048
channels, and in Wide ResNet-50-2 has 2048-1024-2048.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
"""
kwargs['width_per_group'] = 64 * 2
return _resnet('wide_resnet101_2', Bottleneck, [3, 4, 23, 3],
pretrained, progress, **kwargs)
"""
Xception is adapted from https://github.com/Cadene/pretrained-models.pytorch/blob/master/pretrainedmodels/models/xception.py
Ported to pytorch thanks to [tstandley](https://github.com/tstandley/Xception-PyTorch)
@author: tstandley
Adapted by cadene
Creates an Xception Model as defined in:
Francois Chollet
Xception: Deep Learning with Depthwise Separable Convolutions
https://arxiv.org/pdf/1610.02357.pdf
This weights ported from the Keras implementation. Achieves the following performance on the validation set:
Loss:0.9173 Prec@1:78.892 Prec@5:94.292
REMEMBER to set your image size to 3x299x299 for both test and validation
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5],
std=[0.5, 0.5, 0.5])
The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
"""
from __future__ import print_function, division, absolute_import
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
__all__ = ['xception']
pretrained_settings = {
'xception': {
'imagenet': {
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
'input_space': 'RGB',
'input_size': [3, 299, 299],
'input_range': [0, 1],
'mean': [0.5, 0.5, 0.5],
'std': [0.5, 0.5, 0.5],
'num_classes': 1000,
'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
}
}
}
class SeparableConv2d(nn.Module):
def __init__(self,in_channels,out_channels,kernel_size=1,stride=1,padding=0,dilation=1,bias=False):
super(SeparableConv2d,self).__init__()
self.conv1 = nn.Conv2d(in_channels,in_channels,kernel_size,stride,padding,dilation,groups=in_channels,bias=bias)
self.pointwise = nn.Conv2d(in_channels,out_channels,1,1,0,1,1,bias=bias)
def forward(self,x):
x = self.conv1(x)
x = self.pointwise(x)
return x
class Block(nn.Module):
def __init__(self,in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilation=1):
super(Block, self).__init__()
if out_filters != in_filters or strides!=1:
self.skip = nn.Conv2d(in_filters,out_filters,1,stride=strides, bias=False)
self.skipbn = nn.BatchNorm2d(out_filters)
else:
self.skip=None
rep=[]
filters=in_filters
if grow_first:
rep.append(nn.ReLU(inplace=True))
rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation, dilation=dilation, bias=False))
rep.append(nn.BatchNorm2d(out_filters))
filters = out_filters
for i in range(reps-1):
rep.append(nn.ReLU(inplace=True))
rep.append(SeparableConv2d(filters,filters,3,stride=1,padding=dilation,dilation=dilation,bias=False))
rep.append(nn.BatchNorm2d(filters))
if not grow_first:
rep.append(nn.ReLU(inplace=True))
rep.append(SeparableConv2d(in_filters,out_filters,3,stride=1,padding=dilation,dilation=dilation,bias=False))
rep.append(nn.BatchNorm2d(out_filters))
if not start_with_relu:
rep = rep[1:]
else:
rep[0] = nn.ReLU(inplace=False)
if strides != 1:
rep.append(nn.MaxPool2d(3,strides,1))
self.rep = nn.Sequential(*rep)
def forward(self,inp):
x = self.rep(inp)
if self.skip is not None:
skip = self.skip(inp)
skip = self.skipbn(skip)
else:
skip = inp
x+=skip
return x
class Xception(nn.Module):
"""
Xception optimized for the ImageNet dataset, as specified in
https://arxiv.org/pdf/1610.02357.pdf
"""
def __init__(self, num_classes=1000, replace_stride_with_dilation=None):
""" Constructor
Args:
num_classes: number of classes
"""
super(Xception, self).__init__()
self.num_classes = num_classes
self.dilation = 1
if replace_stride_with_dilation is None:
# each element in the tuple indicates if we should replace
# the 2x2 stride with a dilated convolution instead
replace_stride_with_dilation = [False, False, False, False]
if len(replace_stride_with_dilation) != 4:
raise ValueError("replace_stride_with_dilation should be None "
"or a 4-element tuple, got {}".format(replace_stride_with_dilation))
self.conv1 = nn.Conv2d(3, 32, 3,2, 0, bias=False) # 1 / 2
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(32,64,3,bias=False)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=True)
#do relu here
self.block1=self._make_block(64,128,2,2,start_with_relu=False,grow_first=True, dilate=replace_stride_with_dilation[0]) # 1 / 4
self.block2=self._make_block(128,256,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[1]) # 1 / 8
self.block3=self._make_block(256,728,2,2,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2]) # 1 / 16
self.block4=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block5=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block6=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block7=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block8=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block9=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block10=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block11=self._make_block(728,728,3,1,start_with_relu=True,grow_first=True, dilate=replace_stride_with_dilation[2])
self.block12=self._make_block(728,1024,2,2,start_with_relu=True,grow_first=False, dilate=replace_stride_with_dilation[3]) # 1 / 32
self.conv3 = SeparableConv2d(1024,1536,3,1,1, dilation=self.dilation)
self.bn3 = nn.BatchNorm2d(1536)
self.relu3 = nn.ReLU(inplace=True)
#do relu here
self.conv4 = SeparableConv2d(1536,2048,3,1,1, dilation=self.dilation)
self.bn4 = nn.BatchNorm2d(2048)
self.fc = nn.Linear(2048, num_classes)
# #------- init weights --------
# for m in self.modules():
# if isinstance(m, nn.Conv2d):
# n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
# m.weight.data.normal_(0, math.sqrt(2. / n))
# elif isinstance(m, nn.BatchNorm2d):
# m.weight.data.fill_(1)
# m.bias.data.zero_()
# #-----------------------------
def _make_block(self, in_filters,out_filters,reps,strides=1,start_with_relu=True,grow_first=True, dilate=False):
if dilate:
self.dilation *= strides
strides = 1
return Block(in_filters,out_filters,reps,strides,start_with_relu=start_with_relu,grow_first=grow_first, dilation=self.dilation)
def features(self, input):
x = self.conv1(input)
x = self.bn1(x)
x = self.relu1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.block1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
x = self.block8(x)
x = self.block9(x)
x = self.block10(x)
x = self.block11(x)
x = self.block12(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.conv4(x)
x = self.bn4(x)
return x
def logits(self, features):
x = nn.ReLU(inplace=True)(features)
x = F.adaptive_avg_pool2d(x, (1, 1))
x = x.view(x.size(0), -1)
x = self.last_linear(x)
return x
def forward(self, input):
x = self.features(input)
x = self.logits(x)
return x
def xception(num_classes=1000, pretrained='imagenet', replace_stride_with_dilation=None):
model = Xception(num_classes=num_classes, replace_stride_with_dilation=replace_stride_with_dilation)
if pretrained:
settings = pretrained_settings['xception'][pretrained]
assert num_classes == settings['num_classes'], \
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
model = Xception(num_classes=num_classes, replace_stride_with_dilation=replace_stride_with_dilation)
model.load_state_dict(model_zoo.load_url(settings['url']))
# TODO: ugly
model.last_linear = model.fc
del model.fc
return model
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment