Commit 7339f0b0 authored by chenzk

v1.0
import torch
from torch import nn
from torch.autograd import Variable
from torch.nn import functional as F
import torch.utils.data
from torchvision.models.inception import inception_v3
import os
from skimage import io
import cv2
import numpy as np
from scipy.stats import entropy
import torchvision.datasets as dset
import torchvision.transforms as transforms
import argparse
def inception_score(imgs, cuda=True, batch_size=32, resize=False, splits=32):
"""Computes the inception score of the generated images imgs
imgs -- Torch dataset of (3xHxW) numpy images normalized in the range [-1, 1]
cuda -- whether or not to run on GPU
batch_size -- batch size for feeding into Inception v3
splits -- number of splits
"""
N = len(imgs)
assert batch_size > 0
assert N > batch_size
# Set up dtype
if cuda:
dtype = torch.cuda.FloatTensor
else:
if torch.cuda.is_available():
print("WARNING: You have a CUDA device, so you should probably set cuda=True")
dtype = torch.FloatTensor
# Set up dataloader
dataloader = torch.utils.data.DataLoader(imgs, batch_size=batch_size)
# Load inception model
inception_model = inception_v3(pretrained=True, transform_input=False).type(dtype)
    inception_model.eval()
up = nn.Upsample(size=(299, 299), mode='bilinear').type(dtype)
def get_pred(x):
if resize:
x = up(x)
x = inception_model(x)
        return F.softmax(x, dim=1).data.cpu().numpy()
# Get predictions
preds = np.zeros((N, 1000))
for i, batch in enumerate(dataloader, 0):
batch = batch.type(dtype)
batchv = Variable(batch)
batch_size_i = batch.size()[0]
preds[i*batch_size:i*batch_size + batch_size_i] = get_pred(batchv)
# Now compute the mean kl-div
split_scores = []
for k in range(splits):
part = preds[k * (N // splits): (k+1) * (N // splits), :]
py = np.mean(part, axis=0)
scores = []
for i in range(part.shape[0]):
pyx = part[i, :]
scores.append(entropy(pyx, py))
split_scores.append(np.exp(np.mean(scores)))
return np.mean(split_scores), np.std(split_scores)
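# The Inception Score computed above is
#   IS = exp( E_x [ KL( p(y|x) || p(y) ) ] )
# where p(y|x) is the Inception-v3 class distribution for a generated image x and
# p(y) is the marginal class distribution over a split; the function returns the
# mean and standard deviation of the per-split scores.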
class UnlabeledDataset(torch.utils.data.Dataset):
def __init__(self, folder, transform=None):
self.folder = folder
self.transform = transform
self.image_files = os.listdir(folder)
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_file = self.image_files[idx]
image_path = os.path.join(self.folder, image_file)
image = io.imread(image_path)
if self.transform:
image = self.transform(image)
return image
class IgnoreLabelDataset(torch.utils.data.Dataset):
def __init__(self, orig):
self.orig = orig
def __getitem__(self, index):
return self.orig[index][0]
def __len__(self):
return len(self.orig)
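# IgnoreLabelDataset wraps a labeled torchvision dataset (e.g. CIFAR-10) and
# returns only the image tensor, so inception_score receives images without labels.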
if __name__ == '__main__':
# cifar = dset.CIFAR10(root='data/', download=True,
# transform=transforms.Compose([
    # transforms.Resize(32),
# transforms.ToTensor(),
# transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# ])
# )
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# set args
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', type=str, default='/data/wyli/code/TinyDDPM/Output/unet_busi/Gens/')
args = parser.parse_args()
dataset = UnlabeledDataset(args.data_root, transform=transform)
print ("Calculating Inception Score...")
print (inception_score(dataset, cuda=True, batch_size=1, resize=True, splits=10))
Download released_models.zip and unzip it here.
pytorch-fid==0.3.0
torch==2.3.0
torchvision==0.18.0
tqdm
timm==0.9.16
scikit-image==0.23.1
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np
# Define the source and destination directories
src_dir = '/data/wyli/data/CVC-ClinicDB/Original/'
dst_dir = '/data/wyli/data/cvc/images_64/'
os.makedirs(dst_dir, exist_ok=True)
# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
# Define the size of the crop box
crop_size = np.array([288, 288])
# Define the size of the resized image
resize_size = (64, 64)
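# Note: the center crop assumes every source image is at least crop_size
# (288x288) in both spatial dimensions; smaller images would yield an invalid
# crop window.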
for image_file in image_files:
# Load the image
image = io.imread(os.path.join(src_dir, image_file))
# print(image.shape)
# Calculate the center of the image
center = np.array(image.shape[:2]) // 2
# Calculate the start and end points of the crop box
start = center - crop_size // 2
end = start + crop_size
# Crop the image
cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])
# Resize the cropped image
resized_image = transform.resize(cropped_image, resize_size, mode='reflect')
# Save the resized image to the destination directory
io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np
# Define the source and destination directories
src_dir = '/data/wyli/data/busi/images/'
dst_dir = '/data/wyli/data/busi/images_64/'
os.makedirs(dst_dir, exist_ok=True)
# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
# Define the size of the crop box
crop_size = np.array([400, 400])
# Define the size of the resized image
resize_size = (64, 64)
for image_file in image_files:
# Load the image
image = io.imread(os.path.join(src_dir, image_file))
print(image.shape)
# Calculate the center of the image
center = np.array(image.shape[:2]) // 2
# Calculate the start and end points of the crop box
start = center - crop_size // 2
end = start + crop_size
# Crop the image
cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])
# Resize the cropped image
resized_image = transform.resize(cropped_image, resize_size, mode='reflect')
# Save the resized image to the destination directory
io.imsave(os.path.join(dst_dir, image_file), img_as_ubyte(resized_image))
import os
from skimage import io, transform
from skimage.util import img_as_ubyte
import numpy as np
import random
# Define the source and destination directories
src_dir = '/data/wyli/data/glas/images/'
dst_dir = '/data/wyli/data/glas/images_64/'
os.makedirs(dst_dir, exist_ok=True)
# Get a list of all the image files in the source directory
image_files = [f for f in os.listdir(src_dir) if os.path.isfile(os.path.join(src_dir, f))]
# Define the size of the crop box
crop_size = np.array([64, 64])
# Define the number of crops per image
K = 5
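# Each source image yields K random 64x64 crops; the crop origin is sampled
# uniformly so that the crop window always lies inside the image bounds.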
for image_file in image_files:
# Load the image
image = io.imread(os.path.join(src_dir, image_file))
# Get the size of the image
image_size = np.array(image.shape[:2])
for i in range(K):
# Calculate a random start point for the crop box
start = np.array([random.randint(0, image_size[0] - crop_size[0]), random.randint(0, image_size[1] - crop_size[1])])
# Calculate the end point of the crop box
end = start + crop_size
# Crop the image
cropped_image = img_as_ubyte(image[start[0]:end[0], start[1]:end[1]])
# Save the cropped image to the destination directory
io.imsave(os.path.join(dst_dir, f"{image_file}_{i}.png"), cropped_image)
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate kan
GPU=0
MODEL='UKan_Hybrid'
EXP_NME='UKan_busi'
SAVE_ROOT='./Output/'
DATASET='busi'
cd ../
CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME} \
--batch_size 32 \
--channel 64 \
--dataset ${DATASET} \
--epoch 5000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4
# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1
cd inception-score-pytorch
CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate kan
GPU=0
MODEL='UKan_Hybrid'
EXP_NME='UKan_cvc'
SAVE_ROOT='./Output/'
DATASET='cvc'
cd ../
CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME} \
--batch_size 32 \
--channel 64 \
--dataset ${DATASET} \
--epoch 1000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4
# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1
cd inception-score-pytorch
CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1
#!/bin/bash
source ~/miniconda3/etc/profile.d/conda.sh
conda activate kan
GPU=0
MODEL='UKan_Hybrid'
EXP_NME='UKan_glas'
SAVE_ROOT='./Output/'
DATASET='glas'
cd ../
CUDA_VISIBLE_DEVICES=${GPU} python Main.py \
--model ${MODEL} \
--exp_nme ${EXP_NME} \
--batch_size 32 \
--channel 64 \
--dataset ${DATASET} \
--epoch 1000 \
--save_root ${SAVE_ROOT}
# --lr 1e-4
# calculate FID and IS
CUDA_VISIBLE_DEVICES=${GPU} python -m pytorch_fid "data/${DATASET}/images_64/" "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/FID.txt" 2>&1
cd inception-score-pytorch
CUDA_VISIBLE_DEVICES=${GPU} python inception_score.py --data-root "${SAVE_ROOT}/${EXP_NME}/Gens" > "${SAVE_ROOT}/${EXP_NME}/IS.txt" 2>&1
# U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation
:pushpin: This is an official PyTorch implementation of **U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**
[[`Project Page`](https://yes-ukan.github.io/)] [[`arXiv`](https://arxiv.org/abs/2406.02918)] [[`BibTeX`](#citation)]
<p align="center">
<img src="./assets/logo_1.png" alt="" width="120" height="120">
</p>
> [**U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation**](https://arxiv.org/abs/2406.02918)<br>
> [Chenxin Li](https://xggnet.github.io/)\*, [Xinyu Liu](https://xinyuliu-jeffrey.github.io/)\*, [Wuyang Li](https://wymancv.github.io/wuyang.github.io/)\*, [Cheng Wang](https://scholar.google.com/citations?user=AM7gvyUAAAAJ&hl=en)\*, [Hengyu Liu](), [Yixuan Yuan](https://www.ee.cuhk.edu.hk/~yxyuan/people/people.htm)<sup>✉</sup><br>The Chinese University of Hong Kong
We explore the untapped potential of the Kolmogorov-Arnold Network (a.k.a. KAN) for improving backbones in vision tasks. We investigate, modify and re-design the established U-Net pipeline by integrating dedicated KAN layers on the tokenized intermediate representation, termed U-KAN. Rigorous medical image segmentation benchmarks verify the superiority of U-KAN, which achieves higher accuracy even at lower computation cost. We further delve into the potential of U-KAN as an alternative U-Net noise predictor in diffusion models, demonstrating its applicability to generation-oriented model architectures. These endeavours unveil valuable insights and shed light on the prospect that, with U-KAN, you can build a strong backbone for medical image segmentation and generation.
<div align="center">
<img width="100%" alt="UKAN overview" src="assets/framework-1.jpg"/>
</div>
## 📰News
**[2024.6]** Some modifications have been made in Seg_UKAN for better reproduction of the reported performance. Existing checkouts can be updated quickly by replacing the contents of train.py and archs.py with the new versions.
**[2024.6]** Model checkpoints and training logs are released!
**[2024.6]** Code and paper of U-KAN are released!
## 💡Key Features
- The first effort to incorporate the advantages of the emerging KAN into the established U-Net pipeline, making it more **accurate, efficient and interpretable**.
- A Segmentation U-KAN with a **tokenized KAN block that effectively steers the KAN operators** to be compatible with existing convolution-based designs.
- A Diffusion U-KAN as an **improved noise predictor**, demonstrating its potential as a backbone for generative tasks and broader vision settings.
## 🛠Setup
```bash
git clone https://github.com/CUHK-AIM-Group/U-KAN.git
cd U-KAN
conda create -n ukan python=3.10
conda activate ukan
cd Seg_UKAN && pip install -r requirements.txt
```
**Tip A**: We tested the framework with pytorch=1.13.0 and CUDA compile version 11.6. Other versions should also work, but this has not been fully verified.
## 📚Data Preparation
**BUSI**: The dataset can be found [here](https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset).
**GLAS**: The dataset can be found [here](https://websignon.warwick.ac.uk/origin/slogin?shire=https%3A%2F%2Fwarwick.ac.uk%2Fsitebuilder2%2Fshire-read&providerId=urn%3Awarwick.ac.uk%3Asitebuilder2%3Aread%3Aservice&target=https%3A%2F%2Fwarwick.ac.uk%2Ffac%2Fcross_fac%2Ftia%2Fdata%2Fglascontest&status=notloggedin).
<!-- You can directly use the [processed GLAS data]() without further data processing. -->
**CVC-ClinicDB**: The dataset can be found [here](https://www.dropbox.com/s/p5qe9eotetjnbmq/CVC-ClinicDB.rar?e=3&dl=0).
<!-- You can directly use the [processed CVC-ClinicDB data]() without further data processing. -->
We also provide all the [pre-processed dataset](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/ErDlT-t0WoBNlKhBlbYfReYB-iviSCmkNRb1GqZ90oYjJA?e=hrPNWD) without requiring any further data processing. You can directly download and put them into the data dir.
The resulting file structure is as follows.
```
Seg_UKAN
├── inputs
│ ├── busi
│ ├── images
│ ├── malignant (1).png
| ├── ...
| ├── masks
│ ├── 0
│ ├── malignant (1)_mask.png
| ├── ...
│ ├── GLAS
│ ├── images
│ ├── 0.png
| ├── ...
| ├── masks
│ ├── 0
│ ├── 0.png
| ├── ...
│ ├── CVC-ClinicDB
│ ├── images
│ ├── 0.png
| ├── ...
| ├── masks
│ ├── 0
│ ├── 0.png
| ├── ...
```
## 🔖Evaluating Segmentation U-KAN
You can directly evaluate U-KAN from a checkpoint. Here is a quick example of using our **pre-trained models** from the [Segmentation Model Zoo](#segmentation-model-zoo):
1. Download the pre-trained weights and put them at ```{args.output_dir}/{args.name}/model.pth```
2. Run the following script:
```bash
cd Seg_UKAN
python val.py --name ${dataset}_UKAN --output_dir [YOUR_OUTPUT_DIR]
```
## ⏳Training Segmentation U-KAN
You can simply train U-KAN on a single GPU by specifying the dataset name ```--dataset``` and input size ```--input_size```.
```bash
cd Seg_UKAN
python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR]
```
For example, to train U-KAN at a resolution of 256x256 on a single GPU with the BUSI dataset in the ```inputs``` dir:
```bash
cd Seg_UKAN
python train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN --data_dir ./inputs
```
Please see Seg_UKAN/scripts.sh for more details.
Note that the resolution for GLAS is 512x512, differing from the other datasets (256x256).
## 🎪Segmentation Model Zoo
We provide all the pre-trained model [checkpoints](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ej6yZBSIrU5Ds9q-gQdhXqwBbpov5_MaWF483uZHm2lccA?e=rmlHMo).
Here is an overview of the released performance and checkpoints. Note that single-run results may differ from the averaged results reported in the paper.
|Method| Dataset | IoU | F1 | Checkpoints |
|-----|------|-----|-----|-----|
|Seg U-KAN| BUSI | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|
|Seg U-KAN| GLAS | 87.51 | 93.33 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EunQ9KRf6n1AqCJ40FWZF-QB25GMOoF7hoIwU15fefqFbw?e=m7kXwe)|
|Seg U-KAN| CVC-ClinicDB | 85.61 | 92.19 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/Ekhb3PEmwZZMumSG69wPRRQBymYIi0PFNuLJcVNmmK1fjA?e=5XzVSi)|
The flag ```--no_kan``` denotes the baseline model in which the KAN layers are replaced with MLP layers. Please note: occasional inconsistencies in performance are expected; averaging results over multiple runs reveals a clearer trend.
|Method| Layer Type | IoU | F1 | Checkpoints |
|-----|------|-----|-----|-----|
|Seg U-KAN (--no_kan)| MLP Layer | 63.49 | 77.07 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EmEH_qokqIFNtP59yU7vY_4Bq4Yc424zuYufwaJuiAGKiw?e=IJ3clx)|
|Seg U-KAN| KAN Layer | 65.26 | 78.75 | [Link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/EjktWkXytkZEgN3EzN2sJKIBfHCeEnJnCnazC68pWCy7kQ?e=4JBLIc)|
## 🎇Medical Image Generation with Diffusion U-KAN
Please refer to [Diffusion_UKAN](./Diffusion_UKAN/README.md)
## 🛒TODO List
- [X] Release code for Seg U-KAN.
- [X] Release code for Diffusion U-KAN.
- [X] Upload the pretrained checkpoints.
## 🎈Acknowledgements
We greatly appreciate the tremendous effort behind the following projects!
- [CKAN](https://github.com/AntonioTepsich/Convolutional-KANs)
## 📜Citation
If you find this work helpful for your project, please consider citing the following paper:
```
@article{li2024u,
title={U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation},
author={Li, Chenxin and Liu, Xinyu and Li, Wuyang and Wang, Cheng and Liu, Hengyu and Yuan, Yixuan},
journal={arXiv preprint arXiv:2406.02918},
year={2024}
}
```
MIT License
Copyright (c) 2022 Jeya Maria Jose
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
# U-KAN
U-KAN achieves higher accuracy than U-Net models built with Mamba or Transformer blocks; this project further optimizes accuracy on top of U-KAN.
## Paper
`U-KAN Makes Strong Backbone for Medical Image Segmentation and Generation`
- https://arxiv.org/pdf/2406.02918
## Model Structure
KAN has emerged as a promising alternative to the MLP. U-KAN integrates the strengths of the emerging KAN into the established U-Net pipeline to improve accuracy while adding interpretability. This project further introduces partial QKV attention, a KAN FFN, and related modules to strengthen feature extraction and compensate for KAN's shortcomings.
<div align=center>
<img src="./doc/structure.png"/>
</div>
## Algorithm
Like other image segmentation algorithms, U-KAN feeds the input image through preprocessing and feature extraction, and finally predicts per-pixel classes with a convolution to produce the segmentation.
<div align=center>
<img src="./doc/algorithm.png"/>
</div>
## Environment Setup
```
mv U-KAN-optimize_pytorch U-KAN # drop the framework-name suffix
# if torch>2.0, modify /usr/local/lib/python3.10/site-packages/timm/models/layers/helpers.py: from torch._six import container_abcs -> import collections.abc as container_abcs
```
### Docker (Option 1)
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
# replace <your IMAGE ID> with the ID of the image pulled above; for this image it is a4dd5be0ca23
docker run -it --shm-size=32G -v $PWD/U-KAN:/home/U-KAN -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video --name ukan <your IMAGE ID> bash
cd /home/U-KAN
pip install -r Seg_UKAN/requirements.txt # requirements.txt
```
### Dockerfile (Option 2)
```
cd U-KAN/docker
docker build --no-cache -t ukan:latest .
docker run --shm-size=32G --name ukan -v /opt/hyhal:/opt/hyhal:ro --privileged=true --device=/dev/kfd --device=/dev/dri/ --group-add video -v $PWD/../../U-KAN:/home/U-KAN -it ukan bash
# If installing the environment via the Dockerfile takes too long, comment out the pip install inside it and install the Python packages after the container starts: pip install -r Seg_UKAN/requirements.txt
```
### Anaconda (Option 3)
1. The DCU-specific deep learning libraries required by this project can be downloaded and installed from the 光合 developer community:
- https://developer.hpccube.com/tool/
```
DTK driver: dtk24.04.1
Python: python3.10
torch: 2.1.0
torchvision: 0.16.0
```
`Tip: the versions of the DTK driver, Python, torch and other DCU-related tools above must match each other exactly.`
2. Install the remaining, non-DCU-specific libraries according to requirements.txt
```
pip install -r Seg_UKAN/requirements.txt # requirements.txt
```
## Dataset
`BUSI (Breast Ultrasound Image)`
- https://www.kaggle.com/datasets/aryashah2k/breast-ultrasound-images-dataset
There is no need to download the raw dataset for this project; simply use the [pre-processed dataset](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155206760_link_cuhk_edu_hk/ErDlT-t0WoBNlKhBlbYfReYB-iviSCmkNRb1GqZ90oYjJA?e=hrPNWD) provided by the U-KAN authors.
The repository already includes [`busi`](./inputs/busi.zip.zip); unzip it before use. The training data directory structure is as follows:
```
Seg_UKAN
├── inputs
│ ├── busi
│ ├── images
│ ├── malignant (1).png
| ├── ...
| ├── masks
│ ├── 0
│ ├── malignant (1)_mask.png
| ├── ...
```
Fast dataset download center: SCNet AIDatasets. The pre-processed dataset used in this project can be downloaded from the fast download channel: [busi_cvc_glas_preprocessed](http://113.200.138.88:18080/aidatasets/project-dependency/busi_cvc_glas_preprocessed.git)
## Training
### Single Node, Single GPU
```
# optimize the algorithm on the public busi dataset
cd Seg_UKAN
python train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN --data_dir ./inputs
```
For more details, see the original project's [`README_origin`](../README_origin.md).
## Results
<div align=center>
<img src="./doc/seg.png"/>
</div>
### Accuracy
Dataset: busi, max epochs: 400, training framework: PyTorch.
| Method | Dice |
|:---------:|:------:|
| U-KAN | 78.75% |
| U-KAN-optimize | 79.64% |
## Application Scenarios
### Algorithm Category
`Image segmentation`
### Key Application Industries
`Healthcare, e-commerce, manufacturing, energy`
## Source Repository and Issue Reporting
- http://developer.hpccube.com/codes/modelzoo/repvit-optimize_pytorch.git
## References
- https://github.com/CUHK-AIM-Group/U-KAN.git
- https://github.com/KindXiaoming/pykan.git
- https://kindxiaoming.github.io/pykan/
import torch
import torchvision
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.utils import save_image
import torch.nn.functional as F
import os
import matplotlib.pyplot as plt
from utils import *
import timm
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import types
import math
from abc import ABCMeta, abstractmethod
# from mmcv.cnn import ConvModule
from pdb import set_trace as st
from kan import KANLinear, KAN
from torch.nn import init
__all__ = ['UKAN']
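# Tokenized KAN layer: three KANLinear projections (or plain nn.Linear when
# no_kan=True), each followed by a depthwise 3x3 conv + BN + ReLU applied over
# the (H, W) token grid. Input and output are token sequences of shape (B, N, C).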
class KANLayer(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0., no_kan=False):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.dim = in_features
grid_size=5
spline_order=3
scale_noise=0.1
scale_base=1.0
scale_spline=1.0
base_activation=torch.nn.SiLU
grid_eps=0.02
grid_range=[-1, 1]
if not no_kan:
self.fc1 = KANLinear(
in_features,
hidden_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
self.fc2 = KANLinear(
hidden_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
self.fc3 = KANLinear(
hidden_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
# # TODO
# self.fc4 = KANLinear(
# hidden_features,
# out_features,
# grid_size=grid_size,
# spline_order=spline_order,
# scale_noise=scale_noise,
# scale_base=scale_base,
# scale_spline=scale_spline,
# base_activation=base_activation,
# grid_eps=grid_eps,
# grid_range=grid_range,
# )
else:
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.fc3 = nn.Linear(hidden_features, out_features)
# TODO
# self.fc1 = nn.Linear(in_features, hidden_features)
self.dwconv_1 = DW_bn_relu(hidden_features)
self.dwconv_2 = DW_bn_relu(hidden_features)
self.dwconv_3 = DW_bn_relu(hidden_features)
# # TODO
# self.dwconv_4 = DW_bn_relu(hidden_features)
self.drop = nn.Dropout(drop)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
# pdb.set_trace()
B, N, C = x.shape
x = self.fc1(x.reshape(B*N,C))
x = x.reshape(B,N,C).contiguous()
x = self.dwconv_1(x, H, W)
x = self.fc2(x.reshape(B*N,C))
x = x.reshape(B,N,C).contiguous()
x = self.dwconv_2(x, H, W)
x = self.fc3(x.reshape(B*N,C))
x = x.reshape(B,N,C).contiguous()
x = self.dwconv_3(x, H, W)
# # TODO
# x = x.reshape(B,N,C).contiguous()
# x = self.dwconv_4(x, H, W)
return x
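# KAN-based feed-forward module used inside PSK: two KANLinear layers expand and
# project the channels (c -> 2c -> c) over the flattened feature map, then the
# result is reshaped back to (B, C, H, W).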
class KANLayer_ffn(nn.Module):
def __init__(self, c, act_layer=nn.GELU, drop=0., no_kan=False):
super().__init__()
in_features = c
out_features = c
hidden_features = 2*c
grid_size=5
spline_order=3
scale_noise=0.1
scale_base=1.0
scale_spline=1.0
base_activation=torch.nn.SiLU
grid_eps=0.02
grid_range=[-1, 1]
if not no_kan:
self.fc1 = KANLinear(
in_features,
hidden_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
self.fc2 = KANLinear(
hidden_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
else:
self.fc1 = nn.Linear(in_features, hidden_features)
self.fc2 = nn.Linear(hidden_features, out_features)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
# pdb.set_trace()
B, C, H, W = x.shape
x = self.fc1(x.reshape(B*H*W,C))
x = self.fc2(x)
x = x.reshape(B, C, H, W).contiguous()
return x
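# Residual KAN block: LayerNorm -> tokenized KANLayer -> DropPath, added back to
# the input tokens.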
class KANBlock(nn.Module):
def __init__(self, dim, drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, no_kan=False):
super().__init__()
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim)
self.layer = KANLayer(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop, no_kan=no_kan)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x, H, W):
x = x + self.drop_path(self.layer(self.norm2(x), H, W))
return x
class DWConv(nn.Module):
def __init__(self, dim=768):
super(DWConv, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = x.flatten(2).transpose(1, 2)
return x
class DW_bn_relu(nn.Module):
def __init__(self, dim=768):
super(DW_bn_relu, self).__init__()
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
self.bn = nn.BatchNorm2d(dim)
self.relu = nn.ReLU()
def forward(self, x, H, W):
B, N, C = x.shape
x = x.transpose(1, 2).view(B, C, H, W)
x = self.dwconv(x)
x = self.bn(x)
x = self.relu(x)
x = x.flatten(2).transpose(1, 2)
return x
def autopad(k, p=None, d=1): # kernel, padding, dilation
"""Pad to 'same' shape outputs."""
if d > 1:
k = d * (k - 1) + 1 if isinstance(k, int) else [d * (x - 1) + 1 for x in k] # actual kernel-size
if p is None:
p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-pad
return p
class Conv(nn.Module):
"""Standard convolution with args(ch_in, ch_out, kernel, stride, padding, groups, dilation, activation)."""
default_act = nn.SiLU() # default activation
def __init__(self, c1, c2, k=1, s=1, p=None, g=1, d=1, act=True):
"""Initialize Conv layer with given arguments including activation."""
super().__init__()
self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p, d), groups=g, dilation=d, bias=False)
self.bn = nn.BatchNorm2d(c2)
self.act = self.default_act if act is True else act if isinstance(act, nn.Module) else nn.Identity()
def forward(self, x):
"""Apply convolution, batch normalization and activation to input tensor."""
return self.act(self.bn(self.conv(x)))
def forward_fuse(self, x):
"""Perform transposed convolution of 2D data."""
return self.act(self.conv(x))
class DWConv2(Conv):
"""Depth-wise convolution."""
def __init__(self, c1, c2, k=1, s=1, p=None, d=1, act=True): # ch_in, ch_out, kernel, stride, dilation, activation
"""Initialize Depth-wise convolution with given parameters."""
super().__init__(c1, c2, k, s, p, g=math.gcd(c1, c2), d=d, act=act)
class PatchEmbed(nn.Module):
""" Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
self.num_patches = self.H * self.W
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
padding=(patch_size[0] // 2, patch_size[1] // 2))
self.norm = nn.LayerNorm(embed_dim)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
elif isinstance(m, nn.Conv2d):
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
fan_out //= m.groups
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
if m.bias is not None:
m.bias.data.zero_()
def forward(self, x):
x = self.proj(x)
_, _, H, W = x.shape
x = x.flatten(2).transpose(1, 2)
x = self.norm(x)
return x, H, W
class ConvLayer(nn.Module):
def __init__(self, in_ch, out_ch):
super(ConvLayer, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True),
nn.Conv2d(out_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
class D_ConvLayer(nn.Module):
def __init__(self, in_ch, out_ch):
super(D_ConvLayer, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_ch, in_ch, 3, padding=1),
nn.BatchNorm2d(in_ch),
nn.ReLU(inplace=True),
nn.Conv2d(in_ch, out_ch, 3, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(inplace=True)
)
def forward(self, input):
return self.conv(input)
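# Multi-head self-attention over the flattened H*W positions: a 1x1 Conv produces
# q/k/v with a reduced key dimension (attn_ratio), and a depthwise 3x3 Conv over v
# acts as a positional encoding before the output projection.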
class Attention(nn.Module):
def __init__(self, dim, num_heads=8,
attn_ratio=0.5):
super().__init__()
self.num_heads = num_heads
self.head_dim = dim // num_heads
self.key_dim = int(self.head_dim * attn_ratio)
self.scale = self.key_dim ** -0.5
        nh_kd = self.key_dim * num_heads
h = dim + nh_kd * 2
self.qkv = Conv(dim, h, 1, act=False)
self.proj = Conv(dim, dim, 1, act=False)
self.pe = Conv(dim, dim, 3, 1, g=dim, act=False)
def forward(self, x):
B, C, H, W = x.shape
N = H * W
qkv = self.qkv(x)
q, k, v = qkv.view(B, self.num_heads, self.key_dim*2 + self.head_dim, N).split([self.key_dim, self.key_dim, self.head_dim], dim=2)
attn = (
(q.transpose(-2, -1) @ k) * self.scale
)
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(B, C, H, W) + self.pe(v.reshape(B, C, H, W))
x = self.proj(x)
return x
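# PSK: a SimAM-style parameter-free attention gate (sigmoid of the energy term y)
# followed by a partial self-attention branch. The 1x1 conv output is split into
# halves a and b; b passes through Attention and the KAN feed-forward module with
# residual connections, and the concatenated result is fused by cv2. The
# convolutional ffn branch is defined but not used in forward.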
class PSK(nn.Module):
def __init__(self, c1, c2, e=0.5, e_lambda=1e-4):
super().__init__()
self.e_lambda = e_lambda
        self.activation = nn.Sigmoid()
assert(c1 == c2)
self.c = int(c1 * e)
self.cv1 = Conv(c1, 2 * self.c, 1, 1)
self.cv2 = Conv(2 * self.c, c1, 1)
self.attn = Attention(self.c, attn_ratio=0.5, num_heads=self.c // 64)
self.ffn = nn.Sequential(
Conv(self.c, self.c*2, 1),
Conv(self.c*2, self.c, 1, act=False)
)
self.kanlayer_ffn= KANLayer_ffn(self.c, act_layer=nn.GELU, drop=0, no_kan=False)
def __repr__(self):
s = self.__class__.__name__ + '('
s += ('lambda=%f)' % self.e_lambda)
return s
@staticmethod
def get_module_name():
return "simam"
def forward(self, x):
bs, c, h, w = x.size()
n = w * h - 1
x_minus_mu_square = (x - x.mean(dim=[2,3], keepdim=True)).pow(2)
y = x_minus_mu_square / (4 * (x_minus_mu_square.sum(dim=[2,3], keepdim=True) / n + self.e_lambda)) + 0.5
        x = x * self.activation(y)
a, b = self.cv1(x).split((self.c, self.c), dim=1)
b = b + self.attn(b)
b = b + self.kanlayer_ffn(b)
return self.cv2(torch.cat((a, b), 1))
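# U-KAN backbone: three convolutional encoder stages (ConvLayer + max-pool), two
# tokenized KAN stages (PatchEmbed + KANBlock), a PSK module at the bottleneck
# (mid) and on the stage-4 skip feature (mid2), and a decoder of D_ConvLayer +
# bilinear upsampling with additive skip connections and two decoder-side
# KANBlocks; a final 1x1 conv maps to num_classes.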
class UKAN(nn.Module):
def __init__(self, num_classes, input_channels=3, deep_supervision=False, img_size=224, patch_size=16, in_chans=3, embed_dims=[256, 320, 512], no_kan=False,
drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm, depths=[1, 1, 1], **kwargs):
super().__init__()
kan_input_dim = embed_dims[0]
self.encoder1 = ConvLayer(3, kan_input_dim//8)
self.encoder2 = ConvLayer(kan_input_dim//8, kan_input_dim//4)
self.encoder3 = ConvLayer(kan_input_dim//4, kan_input_dim)
self.norm3 = norm_layer(embed_dims[1])
self.norm4 = norm_layer(embed_dims[2])
self.dnorm3 = norm_layer(embed_dims[1])
self.dnorm4 = norm_layer(embed_dims[0])
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
self.block1 = nn.ModuleList([KANBlock(
dim=embed_dims[1],
drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
)])
self.block2 = nn.ModuleList([KANBlock(
dim=embed_dims[2],
drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
)])
self.dblock1 = nn.ModuleList([KANBlock(
dim=embed_dims[1],
drop=drop_rate, drop_path=dpr[0], norm_layer=norm_layer
)])
self.dblock2 = nn.ModuleList([KANBlock(
dim=embed_dims[0],
drop=drop_rate, drop_path=dpr[1], norm_layer=norm_layer
)])
self.patch_embed3 = PatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0], embed_dim=embed_dims[1])
self.patch_embed4 = PatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1], embed_dim=embed_dims[2])
self.mid = PSK(2*embed_dims[0], 2*embed_dims[0])
self.mid2 = PSK(embed_dims[1], embed_dims[1])
self.decoder1 = D_ConvLayer(embed_dims[2], embed_dims[1])
self.decoder2 = D_ConvLayer(embed_dims[1], embed_dims[0])
self.decoder3 = D_ConvLayer(embed_dims[0], embed_dims[0]//4)
self.decoder4 = D_ConvLayer(embed_dims[0]//4, embed_dims[0]//8)
self.decoder5 = D_ConvLayer(embed_dims[0]//8, embed_dims[0]//8)
self.final = nn.Conv2d(embed_dims[0]//8, num_classes, kernel_size=1)
self.soft = nn.Softmax(dim =1)
def forward(self, x):
B = x.shape[0]
### Encoder
### Conv Stage
### Stage 1
out = F.relu(F.max_pool2d(self.encoder1(x), 2, 2))
t1 = out
### Stage 2
out = F.relu(F.max_pool2d(self.encoder2(out), 2, 2))
t2 = out
### Stage 3
out = F.relu(F.max_pool2d(self.encoder3(out), 2, 2))
t3 = out
### Tokenized KAN Stage
### Stage 4
out, H, W = self.patch_embed3(out)
for i, blk in enumerate(self.block1):
out = blk(out, H, W)
out = self.norm3(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
t4 = out
### Bottleneck
out, H, W= self.patch_embed4(out)
for i, blk in enumerate(self.block2):
out = blk(out, H, W)
out = self.norm4(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
out = self.mid(out)
### Stage 4
out = F.relu(F.interpolate(self.decoder1(out), scale_factor=(2,2), mode ='bilinear'))
t4 = self.mid2(t4)
out = torch.add(out, t4)
_, _, H, W = out.shape
out = out.flatten(2).transpose(1,2)
for i, blk in enumerate(self.dblock1):
out = blk(out, H, W)
### Stage 3
out = self.dnorm3(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
out = F.relu(F.interpolate(self.decoder2(out),scale_factor=(2,2),mode ='bilinear'))
out = torch.add(out,t3)
_,_,H,W = out.shape
out = out.flatten(2).transpose(1,2)
for i, blk in enumerate(self.dblock2):
out = blk(out, H, W)
out = self.dnorm4(out)
out = out.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
out = F.relu(F.interpolate(self.decoder3(out),scale_factor=(2,2),mode ='bilinear'))
out = torch.add(out,t2)
out = F.relu(F.interpolate(self.decoder4(out),scale_factor=(2,2),mode ='bilinear'))
out = torch.add(out,t1)
out = F.relu(F.interpolate(self.decoder5(out),scale_factor=(2,2),mode ='bilinear'))
return self.final(out)
# --------------------------------------------------------
# Swin Transformer
# Copyright (c) 2021 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ze Liu
# --------------------------------------------------------
import os
import yaml
from yacs.config import CfgNode as CN
_C = CN()
# Base config files
_C.BASE = ['']
# -----------------------------------------------------------------------------
# Data settings
# -----------------------------------------------------------------------------
_C.DATA = CN()
# Batch size for a single GPU, could be overwritten by command line argument
_C.DATA.BATCH_SIZE = 1
# Path to dataset, could be overwritten by command line argument
_C.DATA.DATA_PATH = ''
# Dataset name
_C.DATA.DATASET = 'imagenet'
# Input image size
_C.DATA.IMG_SIZE = 256
# Interpolation to resize image (random, bilinear, bicubic)
_C.DATA.INTERPOLATION = 'bicubic'
# Use zipped dataset instead of folder dataset
# could be overwritten by command line argument
_C.DATA.ZIP_MODE = False
# Cache Data in Memory, could be overwritten by command line argument
_C.DATA.CACHE_MODE = 'part'
# Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.
_C.DATA.PIN_MEMORY = True
# Number of data loading threads
_C.DATA.NUM_WORKERS = 8
# -----------------------------------------------------------------------------
# Model settings
# -----------------------------------------------------------------------------
_C.MODEL = CN()
# Model type
_C.MODEL.TYPE = 'swin'
# Model name
_C.MODEL.NAME = 'swin_tiny_patch4_window7_224'
# Checkpoint to resume, could be overwritten by command line argument
_C.MODEL.PRETRAIN_CKPT = './pretrained_ckpt/swin_tiny_patch4_window7_224.pth'
_C.MODEL.RESUME = ''
# Number of classes, overwritten in data preparation
_C.MODEL.NUM_CLASSES = 1000
# Dropout rate
_C.MODEL.DROP_RATE = 0.0
# Drop path rate
_C.MODEL.DROP_PATH_RATE = 0.1
# Label Smoothing
_C.MODEL.LABEL_SMOOTHING = 0.1
# Swin Transformer parameters
_C.MODEL.SWIN = CN()
_C.MODEL.SWIN.PATCH_SIZE = 4
_C.MODEL.SWIN.IN_CHANS = 3
_C.MODEL.SWIN.EMBED_DIM = 96
_C.MODEL.SWIN.DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.DECODER_DEPTHS = [2, 2, 6, 2]
_C.MODEL.SWIN.NUM_HEADS = [3, 6, 12, 24]
_C.MODEL.SWIN.WINDOW_SIZE = 4
_C.MODEL.SWIN.MLP_RATIO = 4.
_C.MODEL.SWIN.QKV_BIAS = True
_C.MODEL.SWIN.QK_SCALE = False
_C.MODEL.SWIN.APE = False
_C.MODEL.SWIN.PATCH_NORM = True
_C.MODEL.SWIN.FINAL_UPSAMPLE= "expand_first"
# -----------------------------------------------------------------------------
# Training settings
# -----------------------------------------------------------------------------
_C.TRAIN = CN()
_C.TRAIN.START_EPOCH = 0
_C.TRAIN.EPOCHS = 300
_C.TRAIN.WARMUP_EPOCHS = 20
_C.TRAIN.WEIGHT_DECAY = 0.05
_C.TRAIN.BASE_LR = 5e-4
_C.TRAIN.WARMUP_LR = 5e-7
_C.TRAIN.MIN_LR = 5e-6
# Clip gradient norm
_C.TRAIN.CLIP_GRAD = 5.0
# Auto resume from latest checkpoint
_C.TRAIN.AUTO_RESUME = True
# Gradient accumulation steps
# could be overwritten by command line argument
_C.TRAIN.ACCUMULATION_STEPS = 0
# Whether to use gradient checkpointing to save memory
# could be overwritten by command line argument
_C.TRAIN.USE_CHECKPOINT = False
# LR scheduler
_C.TRAIN.LR_SCHEDULER = CN()
_C.TRAIN.LR_SCHEDULER.NAME = 'cosine'
# Epoch interval to decay LR, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_EPOCHS = 30
# LR decay rate, used in StepLRScheduler
_C.TRAIN.LR_SCHEDULER.DECAY_RATE = 0.1
# Optimizer
_C.TRAIN.OPTIMIZER = CN()
_C.TRAIN.OPTIMIZER.NAME = 'adamw'
# Optimizer Epsilon
_C.TRAIN.OPTIMIZER.EPS = 1e-8
# Optimizer Betas
_C.TRAIN.OPTIMIZER.BETAS = (0.9, 0.999)
# SGD momentum
_C.TRAIN.OPTIMIZER.MOMENTUM = 0.9
# -----------------------------------------------------------------------------
# Augmentation settings
# -----------------------------------------------------------------------------
_C.AUG = CN()
# Color jitter factor
_C.AUG.COLOR_JITTER = 0.4
# Use AutoAugment policy. "v0" or "original"
_C.AUG.AUTO_AUGMENT = 'rand-m9-mstd0.5-inc1'
# Random erase prob
_C.AUG.REPROB = 0.25
# Random erase mode
_C.AUG.REMODE = 'pixel'
# Random erase count
_C.AUG.RECOUNT = 1
# Mixup alpha, mixup enabled if > 0
_C.AUG.MIXUP = 0.8
# Cutmix alpha, cutmix enabled if > 0
_C.AUG.CUTMIX = 1.0
# Cutmix min/max ratio, overrides alpha and enables cutmix if set
_C.AUG.CUTMIX_MINMAX = False
# Probability of performing mixup or cutmix when either/both is enabled
_C.AUG.MIXUP_PROB = 1.0
# Probability of switching to cutmix when both mixup and cutmix enabled
_C.AUG.MIXUP_SWITCH_PROB = 0.5
# How to apply mixup/cutmix params. Per "batch", "pair", or "elem"
_C.AUG.MIXUP_MODE = 'batch'
# -----------------------------------------------------------------------------
# Testing settings
# -----------------------------------------------------------------------------
_C.TEST = CN()
# Whether to use center crop when testing
_C.TEST.CROP = True
# -----------------------------------------------------------------------------
# Misc
# -----------------------------------------------------------------------------
# Mixed precision opt level, if O0, no amp is used ('O0', 'O1', 'O2')
# overwritten by command line argument
_C.AMP_OPT_LEVEL = ''
# Path to output folder, overwritten by command line argument
_C.OUTPUT = ''
# Tag of experiment, overwritten by command line argument
_C.TAG = 'default'
# Frequency to save checkpoint
_C.SAVE_FREQ = 1
# Frequency to logging info
_C.PRINT_FREQ = 10
# Fixed random seed
_C.SEED = 0
# Perform evaluation only, overwritten by command line argument
_C.EVAL_MODE = False
# Test throughput only, overwritten by command line argument
_C.THROUGHPUT_MODE = False
# local rank for DistributedDataParallel, given by command line argument
_C.LOCAL_RANK = 0
def _update_config_from_file(config, cfg_file):
config.defrost()
with open(cfg_file, 'r') as f:
yaml_cfg = yaml.load(f, Loader=yaml.FullLoader)
for cfg in yaml_cfg.setdefault('BASE', ['']):
if cfg:
_update_config_from_file(
config, os.path.join(os.path.dirname(cfg_file), cfg)
)
print('=> merge config from {}'.format(cfg_file))
config.merge_from_file(cfg_file)
config.freeze()
def update_config(config, args):
_update_config_from_file(config, args.cfg)
config.defrost()
if args.opts:
config.merge_from_list(args.opts)
# merge from specific arguments
if args.batch_size:
config.DATA.BATCH_SIZE = args.batch_size
if args.zip:
config.DATA.ZIP_MODE = True
if args.cache_mode:
config.DATA.CACHE_MODE = args.cache_mode
if args.resume:
config.MODEL.RESUME = args.resume
if args.accumulation_steps:
config.TRAIN.ACCUMULATION_STEPS = args.accumulation_steps
if args.use_checkpoint:
config.TRAIN.USE_CHECKPOINT = True
if args.amp_opt_level:
config.AMP_OPT_LEVEL = args.amp_opt_level
if args.tag:
config.TAG = args.tag
if args.eval:
config.EVAL_MODE = True
if args.throughput:
config.THROUGHPUT_MODE = True
config.freeze()
def get_config(args):
"""Get a yacs CfgNode object with default values."""
# Return a clone so that the defaults will not be altered
# This is for the "local variable" use pattern
config = _C.clone()
# update_config(config, args)
return config
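# Minimal usage sketch (hypothetical argparse wiring, not defined in this file;
# update_config also reads args.zip, args.cache_mode, args.resume,
# args.accumulation_steps, args.use_checkpoint, args.amp_opt_level, args.tag,
# args.eval and args.throughput, so a real parser must define those too):
#
#   import argparse
#   parser = argparse.ArgumentParser()
#   parser.add_argument('--cfg', type=str, required=True)
#   parser.add_argument('--opts', nargs='+', default=None)
#   parser.add_argument('--batch_size', type=int, default=None)
#   args = parser.parse_args()
#   config = get_config(args)        # clone of the defaults above
#   update_config(config, args)      # merge the YAML file and any CLI overrides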