Commit b7750bcd authored by yongshk

add new
*.pyc
data/
__pycache__/
checkpoints/
*.pth
*.jpg
venv/
.idea/
wandb/
# UNET_pytorch
## Paper
`U-Net: Convolutional Networks for Biomedical Image Segmentation`
- https://arxiv.org/abs/1505.04597
## Model Architecture
UNet (U-Net) is a convolutional neural network (CNN) architecture for image segmentation. Its structure forms the shape of a "U", which gives the network its name.
![img](https://developer.hpccube.com/codes/yongshk/unet-pytorch/-/raw/main/doc/unet.png)
## Algorithm
The core ideas of U-Net are as follows (a minimal sketch appears after the figure below):
1. **Encoder (Contracting Path)**: the encoder consists of convolutional and pooling layers that capture image features while progressively reducing resolution. Its task is to shrink the input image to a low-resolution feature map while retaining the key features of the image content.
2. **Bottleneck**: between the encoder and the decoder, U-Net includes an intermediate stage, usually built from convolutional layers, that extracts further feature information.
3. **Decoder (Expansive Path)**: the decoder consists of upsampling and convolutional layers that restore the feature maps to the resolution of the original input. Its task is to combine high-level features with low-level ones to produce the segmentation result. Its structure mirrors that of the encoder.
![img](https://developer.hpccube.com/codes/yongshk/unet-pytorch/-/raw/main/doc/原理.png)
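To make the skip connection concrete, here is a minimal sketch (illustrative only, not the project code; the full implementation lives in the `unet` package later in this commit). It shows one downsampling step, one upsampling step, and the concatenation that joins them:
```
import torch
import torch.nn as nn
import torch.nn.functional as F

# Minimal illustration of the U shape: encode, downsample, upsample,
# then concatenate the encoder features back in (the skip connection).
class TinyUNet(nn.Module):
    def __init__(self, in_ch=3, n_classes=2):
        super().__init__()
        self.enc = nn.Conv2d(in_ch, 16, kernel_size=3, padding=1)      # encoder conv
        self.pool = nn.MaxPool2d(2)                                    # downsample
        self.mid = nn.Conv2d(16, 32, kernel_size=3, padding=1)         # bottleneck
        self.up = nn.ConvTranspose2d(32, 16, kernel_size=2, stride=2)  # upsample
        self.dec = nn.Conv2d(32, 16, kernel_size=3, padding=1)         # decoder conv
        self.out = nn.Conv2d(16, n_classes, kernel_size=1)             # per-pixel logits

    def forward(self, x):
        e = F.relu(self.enc(x))               # high-resolution features
        m = F.relu(self.mid(self.pool(e)))    # low-resolution features
        u = self.up(m)                        # back to the input resolution
        u = torch.cat([e, u], dim=1)          # skip connection
        return self.out(F.relu(self.dec(u)))  # segmentation logits
```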
## Environment Setup
### Docker (Option 1)
The Docker image can be pulled from [SourceFind](https://www.sourcefind.cn/#/service-details) and run as follows:
```
docker pull image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest
docker run -it --network=host --name=unet --privileged --device=/dev/kfd --device=/dev/dri --ipc=host --shm-size=32G --group-add video --cap-add=SYS_PTRACE --security-opt seccomp=unconfined -u root --ulimit stack=-1:-1 --ulimit memlock=-1:-1 -v /root/unet:/home image.sourcefind.cn:5000/dcu/admin/base/pytorch:1.10.0-centos7.6-dtk-23.04-py37-latest
```
### Anaconda (Option 2)
Detailed steps for a local setup follow. The DCU-specific deep learning libraries required by this project can be downloaded from the [HPC developer community](https://developer.hpccube.com/tool/).
```
DTK driver: dtk23.04
python: python3.7
```
`Tip: the DTK driver, Python, and other DCU-related tool versions above must match exactly.`
Install the remaining (non-deep-learning) dependencies from requirements.txt:
```
pip install -r requirements.txt
```
## Dataset
`Carvana`
- https://www.kaggle.com/c/carvana-image-masking-challenge/data
The data download/preprocessing script is used as follows:
```
bash scripts/download_data.sh
```
A mini dataset for trial training is included in the project. The training data directory is laid out as shown below; prepare the full dataset for regular training in the same structure (a quick image/mask pairing check is sketched after the tree):
```
── data
   ├── imgs
   │   ├── fff9b3a5373f_01.jpg
   │   └── fff9b3a5373f_02.jpg
   └── masks
       ├── fff9b3a5373f_01.gif
       └── fff9b3a5373f_02.gif
```
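A quick way to validate a prepared dataset is to check that every image has a matching mask. A minimal sketch, assuming the `data/` layout above:
```
from pathlib import Path

# Every image in data/imgs should have a matching mask in data/masks.
# Stems are compared directly; the Carvana-style '_mask' suffix used by
# CarvanaDataset is stripped if present.
imgs = {p.stem for p in Path('data/imgs').glob('*.jpg')}
masks = {p.stem[:-5] if p.stem.endswith('_mask') else p.stem
         for p in Path('data/masks').glob('*.gif')}
missing = sorted(imgs - masks)
print(f'{len(imgs)} images, {len(missing)} missing masks: {missing[:5]}')
```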
## Training
### Single node, multiple GPUs
```
python -m torch.distributed.launch --nproc_per_node 4 train_ddp.py
```
### Single node, single GPU
```
python train.py
```
## Inference
```
python predict.py -i image.jpg -o output.jpg
```
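`predict.py` also accepts multiple inputs and extra options (see `get_args()` in predict.py below), for example:
```
python predict.py -i img1.jpg img2.jpg -o mask1.png mask2.png --scale 0.5 --viz
```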
## Results
![result](https://developer.hpccube.com/codes/yongshk/unet-pytorch/-/raw/main/doc/结果.png)
### Accuracy
Test data: [test data](https://www.kaggle.com/c/carvana-image-masking-challenge/data); accelerator: Z100L; metric: IoU coefficient.
Test results:
| Dataset | Accuracy (IoU) | Speed |
| :------: | :------: | :------: |
| Carvana | 0.976 | 25.96 |
## Application Scenarios
### Algorithm category
`Computer vision, image segmentation`
### Typical application domains
`Medical image analysis` `Satellite image analysis` `Natural image segmentation` `Industrial inspection`
## Source Repository and Issue Reporting
- https://developer.hpccube.com/codes/modelzoo/unet-pytorch
## References
- https://github.com/milesial/Pytorch-Unet
import torch
import torch.nn.functional as F
from tqdm import tqdm

from utils.dice_score import multiclass_dice_coeff, dice_coeff, multiclass_iou_coeff, iou_coeff


@torch.inference_mode()
def evaluate(net, dataloader, device, amp):
    net.eval()
    num_val_batches = len(dataloader)
    dice_score = 0

    # iterate over the validation set
    with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
        for batch in tqdm(dataloader, total=num_val_batches, desc='Validation round', unit='batch', leave=False):
            image, mask_true = batch['image'], batch['mask']

            # move images and labels to the correct device and type
            image = image.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
            mask_true = mask_true.to(device=device, dtype=torch.long)

            # predict the mask
            mask_pred = net(image)

            if net.n_classes == 1:
                assert mask_true.min() >= 0 and mask_true.max() <= 1, 'True mask indices should be in [0, 1]'
                mask_pred = (F.sigmoid(mask_pred) > 0.5).float()
                # accumulate the IoU score (the original Dice variant is kept for reference)
                # dice_score += dice_coeff(mask_pred, mask_true, reduce_batch_first=False)
                dice_score += iou_coeff(mask_pred, mask_true, reduce_batch_first=False)
            else:
                assert mask_true.min() >= 0 and mask_true.max() < net.n_classes, 'True mask indices should be in [0, n_classes)'
                # convert to one-hot format
                mask_true = F.one_hot(mask_true, net.n_classes).permute(0, 3, 1, 2).float()
                mask_pred = F.one_hot(mask_pred.argmax(dim=1), net.n_classes).permute(0, 3, 1, 2).float()
                # accumulate the IoU score, ignoring background (Dice variant kept for reference)
                # dice_score += multiclass_dice_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)
                dice_score += multiclass_iou_coeff(mask_pred[:, 1:], mask_true[:, 1:], reduce_batch_first=False)

    net.train()
    return dice_score / max(num_val_batches, 1)
import torch

from unet import UNet as _UNet


def unet_carvana(pretrained=False, scale=0.5):
    """
    UNet model trained on the Carvana dataset ( https://www.kaggle.com/c/carvana-image-masking-challenge/data ).
    Set the scale to 0.5 (50%) when predicting.
    """
    net = _UNet(n_channels=3, n_classes=2, bilinear=False)
    if pretrained:
        if scale == 0.5:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale0.5_epoch2.pth'
        elif scale == 1.0:
            checkpoint = 'https://github.com/milesial/Pytorch-UNet/releases/download/v3.0/unet_carvana_scale1.0_epoch2.pth'
        else:
            raise RuntimeError('Only 0.5 and 1.0 scales are available')
        state_dict = torch.hub.load_state_dict_from_url(checkpoint, progress=True)
        if 'mask_values' in state_dict:
            state_dict.pop('mask_values')
        net.load_state_dict(state_dict)

    return net
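
if __name__ == '__main__':
    # Illustrative usage (assumption: this hubconf mirrors the upstream
    # milesial/Pytorch-UNet hub entry point, which exposes the same
    # 'unet_carvana' function): load the pretrained Carvana weights
    # through torch.hub instead of calling unet_carvana() directly.
    net = torch.hub.load('milesial/Pytorch-UNet', 'unet_carvana',
                         pretrained=True, scale=0.5)
    net.eval()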
import argparse
import logging
import os

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torchvision import transforms

from utils.data_loading import BasicDataset
from unet import UNet
from utils.utils import plot_img_and_mask


def predict_img(net,
                full_img,
                device,
                scale_factor=1,
                out_threshold=0.5):
    net.eval()
    img = torch.from_numpy(BasicDataset.preprocess(None, full_img, scale_factor, is_mask=False))
    img = img.unsqueeze(0)
    img = img.to(device=device, dtype=torch.float32)

    with torch.no_grad():
        output = net(img).cpu()
        output = F.interpolate(output, (full_img.size[1], full_img.size[0]), mode='bilinear')
        if net.n_classes > 1:
            mask = output.argmax(dim=1)
        else:
            mask = torch.sigmoid(output) > out_threshold

    return mask[0].long().squeeze().numpy()


def get_args():
    parser = argparse.ArgumentParser(description='Predict masks from input images')
    parser.add_argument('--model', '-m', default='MODEL.pth', metavar='FILE',
                        help='Specify the file in which the model is stored')
    parser.add_argument('--input', '-i', metavar='INPUT', nargs='+', help='Filenames of input images', required=True)
    parser.add_argument('--output', '-o', metavar='OUTPUT', nargs='+', help='Filenames of output images')
    parser.add_argument('--viz', '-v', action='store_true',
                        help='Visualize the images as they are processed')
    parser.add_argument('--no-save', '-n', action='store_true', help='Do not save the output masks')
    parser.add_argument('--mask-threshold', '-t', type=float, default=0.5,
                        help='Minimum probability value to consider a mask pixel white')
    parser.add_argument('--scale', '-s', type=float, default=0.5,
                        help='Scale factor for the input images')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


def get_output_filenames(args):
    def _generate_name(fn):
        return f'{os.path.splitext(fn)[0]}_OUT.png'

    return args.output or list(map(_generate_name, args.input))


def mask_to_image(mask: np.ndarray, mask_values):
    if isinstance(mask_values[0], list):
        out = np.zeros((mask.shape[-2], mask.shape[-1], len(mask_values[0])), dtype=np.uint8)
    elif mask_values == [0, 1]:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=bool)
    else:
        out = np.zeros((mask.shape[-2], mask.shape[-1]), dtype=np.uint8)

    if mask.ndim == 3:
        mask = np.argmax(mask, axis=0)

    for i, v in enumerate(mask_values):
        out[mask == i] = v

    return Image.fromarray(out)


if __name__ == '__main__':
    args = get_args()
    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')

    in_files = args.input
    out_files = get_output_filenames(args)

    net = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Loading model {args.model}')
    logging.info(f'Using device {device}')

    net.to(device=device)
    state_dict = torch.load(args.model, map_location=device)
    mask_values = state_dict.pop('mask_values', [0, 1])
    net.load_state_dict(state_dict)

    logging.info('Model loaded!')

    for i, filename in enumerate(in_files):
        logging.info(f'Predicting image {filename} ...')
        img = Image.open(filename)

        mask = predict_img(net=net,
                           full_img=img,
                           scale_factor=args.scale,
                           out_threshold=args.mask_threshold,
                           device=device)

        if not args.no_save:
            out_filename = out_files[i]
            result = mask_to_image(mask, mask_values)
            result.save(out_filename)
            logging.info(f'Mask saved to {out_filename}')

        if args.viz:
            logging.info(f'Visualizing results for image {filename}, close to continue...')
            plot_img_and_mask(img, mask)
apex==0.1+f49ddd4.abi0.dtk2304.torch1.13
matplotlib==3.5.3
numpy==1.21.6
Pillow==9.5.0
torch==1.13.1+git55d300e.abi0.dtk2304
torchvision==0.14.1+git9134838.abi0.dtk2304.torch1.13
tqdm==4.65.0
#!/bin/bash

if [[ ! -f ~/.kaggle/kaggle.json ]]; then
  echo -n "Kaggle username: "
  read USERNAME
  echo
  echo -n "Kaggle API key: "
  read APIKEY

  mkdir -p ~/.kaggle
  echo "{\"username\":\"$USERNAME\",\"key\":\"$APIKEY\"}" > ~/.kaggle/kaggle.json
  chmod 600 ~/.kaggle/kaggle.json
fi

pip install kaggle --upgrade

# make sure the target directories exist (data/ is gitignored)
mkdir -p data/imgs data/masks

kaggle competitions download -c carvana-image-masking-challenge -f train_hq.zip
unzip train_hq.zip
mv train_hq/* data/imgs/
rm -d train_hq
rm train_hq.zip

kaggle competitions download -c carvana-image-masking-challenge -f train_masks.zip
unzip train_masks.zip
mv train_masks/* data/masks/
rm -d train_masks
rm train_masks.zip
import argparse
import logging
import os
import random
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

import wandb
from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss, iou_loss

dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')


def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
):
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs: {epochs}
        Batch size: {batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
        Images scaling: {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize the IoU score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss() if model.n_classes > 1 else nn.BCEWithLogitsLoss()
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.n_channels, \
                    f'Network has been defined with {model.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        # loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                        loss += iou_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        # loss += dice_loss(
                        #     F.softmax(masks_pred, dim=1).float(),
                        #     F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                        #     multiclass=True
                        # )
                        loss += iou_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation IoU score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation IoU': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except Exception:
                            pass

        if save_checkpoint:
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # Change here to adapt to your data
    # n_channels=3 for RGB images
    # n_classes is the number of probabilities you want to get per pixel
    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        del state_dict['mask_values']
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(device=device)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.use_checkpointing()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp
        )
import argparse
import logging
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from pathlib import Path
from torch import optim
from torch.utils.data import DataLoader, random_split
from tqdm import tqdm

from evaluate import evaluate
from unet import UNet
from utils.data_loading import BasicDataset, CarvanaDataset
from utils.dice_score import dice_loss, iou_loss

from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import wandb

dir_img = Path('./data/imgs/')
dir_mask = Path('./data/masks/')
dir_checkpoint = Path('./checkpoints/')


def train_model(
        model,
        device,
        epochs: int = 5,
        batch_size: int = 1,
        learning_rate: float = 1e-5,
        val_percent: float = 0.1,
        save_checkpoint: bool = True,
        img_scale: float = 0.5,
        amp: bool = False,
        weight_decay: float = 1e-8,
        momentum: float = 0.999,
        gradient_clipping: float = 1.0,
        local_rank: int = 0,
):
    # 1. Create dataset
    try:
        dataset = CarvanaDataset(dir_img, dir_mask, img_scale)
    except (AssertionError, RuntimeError, IndexError):
        dataset = BasicDataset(dir_img, dir_mask, img_scale)

    # 2. Split into train / validation partitions
    n_val = int(len(dataset) * val_percent)
    n_train = len(dataset) - n_val
    train_set, val_set = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(0))

    # 3. Create data loaders (each rank gets its own shard through the
    # DistributedSampler; the single-process loaders from train.py are kept
    # below for reference)
    # loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    # train_loader = DataLoader(train_set, shuffle=True, **loader_args)
    # val_loader = DataLoader(val_set, shuffle=False, drop_last=True, **loader_args)
    train_sampler = DistributedSampler(train_set)
    val_sampler = DistributedSampler(val_set)
    loader_args = dict(batch_size=batch_size, num_workers=os.cpu_count(), pin_memory=True)
    train_loader = DataLoader(train_set, sampler=train_sampler, **loader_args)
    val_loader = DataLoader(val_set, sampler=val_sampler, **loader_args)

    # (Initialize logging)
    experiment = wandb.init(project='U-Net', resume='allow', anonymous='must')
    experiment.config.update(
        dict(epochs=epochs, batch_size=batch_size, learning_rate=learning_rate,
             val_percent=val_percent, save_checkpoint=save_checkpoint, img_scale=img_scale, amp=amp)
    )

    logging.info(f'''Starting training:
        Epochs: {epochs}
        Batch size: {batch_size}
        Learning rate: {learning_rate}
        Training size: {n_train}
        Validation size: {n_val}
        Checkpoints: {save_checkpoint}
        Device: {device.type}
        Images scaling: {img_scale}
        Mixed Precision: {amp}
    ''')

    # 4. Set up the optimizer, the loss, the learning rate scheduler and the loss scaling for AMP
    optimizer = optim.RMSprop(model.parameters(),
                              lr=learning_rate, weight_decay=weight_decay, momentum=momentum, foreach=True)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', patience=5)  # goal: maximize the IoU score
    grad_scaler = torch.cuda.amp.GradScaler(enabled=amp)
    criterion = nn.CrossEntropyLoss().to(local_rank) if model.module.n_classes > 1 else nn.BCEWithLogitsLoss().to(local_rank)
    global_step = 0

    # 5. Begin training
    for epoch in range(1, epochs + 1):
        model.train()
        train_sampler.set_epoch(epoch)  # reshuffle the per-rank shards each epoch
        epoch_loss = 0
        with tqdm(total=n_train, desc=f'Epoch {epoch}/{epochs}', unit='img') as pbar:
            for batch in train_loader:
                images, true_masks = batch['image'], batch['mask']

                assert images.shape[1] == model.module.n_channels, \
                    f'Network has been defined with {model.module.n_channels} input channels, ' \
                    f'but loaded images have {images.shape[1]} channels. Please check that ' \
                    'the images are loaded correctly.'

                images = images.to(device=device, dtype=torch.float32, memory_format=torch.channels_last)
                true_masks = true_masks.to(device=device, dtype=torch.long)

                with torch.autocast(device.type if device.type != 'mps' else 'cpu', enabled=amp):
                    masks_pred = model(images)
                    if model.module.n_classes == 1:
                        loss = criterion(masks_pred.squeeze(1), true_masks.float())
                        # loss += dice_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                        loss += iou_loss(F.sigmoid(masks_pred.squeeze(1)), true_masks.float(), multiclass=False)
                    else:
                        loss = criterion(masks_pred, true_masks)
                        # loss += dice_loss(
                        #     F.softmax(masks_pred, dim=1).float(),
                        #     F.one_hot(true_masks, model.module.n_classes).permute(0, 3, 1, 2).float(),
                        #     multiclass=True
                        # )
                        loss += iou_loss(
                            F.softmax(masks_pred, dim=1).float(),
                            F.one_hot(true_masks, model.module.n_classes).permute(0, 3, 1, 2).float(),
                            multiclass=True
                        )

                optimizer.zero_grad(set_to_none=True)
                grad_scaler.scale(loss).backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clipping)
                grad_scaler.step(optimizer)
                grad_scaler.update()

                pbar.update(images.shape[0])
                global_step += 1
                epoch_loss += loss.item()
                experiment.log({
                    'train loss': loss.item(),
                    'step': global_step,
                    'epoch': epoch
                })
                pbar.set_postfix(**{'loss (batch)': loss.item()})

                # Evaluation round
                division_step = (n_train // (5 * batch_size))
                if division_step > 0:
                    if global_step % division_step == 0:
                        histograms = {}
                        for tag, value in model.named_parameters():
                            tag = tag.replace('/', '.')
                            if not (torch.isinf(value) | torch.isnan(value)).any():
                                histograms['Weights/' + tag] = wandb.Histogram(value.data.cpu())
                            if not (torch.isinf(value.grad) | torch.isnan(value.grad)).any():
                                histograms['Gradients/' + tag] = wandb.Histogram(value.grad.data.cpu())

                        val_score = evaluate(model.module, val_loader, device, amp)
                        scheduler.step(val_score)

                        logging.info('Validation IoU score: {}'.format(val_score))
                        try:
                            experiment.log({
                                'learning rate': optimizer.param_groups[0]['lr'],
                                'validation IoU': val_score,
                                'images': wandb.Image(images[0].cpu()),
                                'masks': {
                                    'true': wandb.Image(true_masks[0].float().cpu()),
                                    'pred': wandb.Image(masks_pred.argmax(dim=1)[0].float().cpu()),
                                },
                                'step': global_step,
                                'epoch': epoch,
                                **histograms
                            })
                        except Exception:
                            pass

        if save_checkpoint and dist.get_rank() == 0:
            # only rank 0 writes checkpoints, to avoid concurrent writes to the same file
            Path(dir_checkpoint).mkdir(parents=True, exist_ok=True)
            state_dict = model.module.state_dict()
            state_dict['mask_values'] = dataset.mask_values
            torch.save(state_dict, str(dir_checkpoint / 'checkpoint_epoch{}.pth'.format(epoch)))
            logging.info(f'Checkpoint {epoch} saved!')


def get_args():
    parser = argparse.ArgumentParser(description='Train the UNet on images and target masks')
    parser.add_argument('--local_rank', type=int, default=-1)
    parser.add_argument('--epochs', '-e', metavar='E', type=int, default=5, help='Number of epochs')
    parser.add_argument('--batch-size', '-b', dest='batch_size', metavar='B', type=int, default=1, help='Batch size')
    parser.add_argument('--learning-rate', '-l', metavar='LR', type=float, default=1e-5,
                        help='Learning rate', dest='lr')
    parser.add_argument('--load', '-f', type=str, default=False, help='Load model from a .pth file')
    parser.add_argument('--scale', '-s', type=float, default=0.5, help='Downscaling factor of the images')
    parser.add_argument('--validation', '-v', dest='val', type=float, default=10.0,
                        help='Percent of the data that is used as validation (0-100)')
    parser.add_argument('--amp', action='store_true', default=False, help='Use mixed precision')
    parser.add_argument('--bilinear', action='store_true', default=False, help='Use bilinear upsampling')
    parser.add_argument('--classes', '-c', type=int, default=2, help='Number of classes')

    return parser.parse_args()


if __name__ == '__main__':
    args = get_args()
    local_rank = args.local_rank

    logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend='nccl')
    logging.info(f'Using device {device}')

    model = UNet(n_channels=3, n_classes=args.classes, bilinear=args.bilinear)
    model = model.to(memory_format=torch.channels_last)

    logging.info(f'Network:\n'
                 f'\t{model.n_channels} input channels\n'
                 f'\t{model.n_classes} output channels (classes)\n'
                 f'\t{"Bilinear" if model.bilinear else "Transposed conv"} upscaling')

    if args.load:
        state_dict = torch.load(args.load, map_location=device)
        del state_dict['mask_values']
        model.load_state_dict(state_dict)
        logging.info(f'Model loaded from {args.load}')

    model.to(local_rank)
    model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    try:
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp,
            local_rank=local_rank,
        )
    except torch.cuda.OutOfMemoryError:
        logging.error('Detected OutOfMemoryError! '
                      'Enabling checkpointing to reduce memory usage, but this slows down training. '
                      'Consider enabling AMP (--amp) for fast and memory efficient training')
        torch.cuda.empty_cache()
        model.module.use_checkpointing()
        train_model(
            model=model,
            epochs=args.epochs,
            batch_size=args.batch_size,
            learning_rate=args.lr,
            device=device,
            img_scale=args.scale,
            val_percent=args.val / 100,
            amp=args.amp,
            local_rank=local_rank,
        )
python -m torch.distributed.launch --nproc_per_node 4 train_ddp.py
Here `4` is the number of processes to launch; it should not exceed the number of available GPUs.
from .unet_model import UNet
""" Full assembly of the parts to form the complete network """
from .unet_parts import *
class UNet(nn.Module):
    def __init__(self, n_channels, n_classes, bilinear=False):
        super(UNet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear

        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        self.down2 = Down(128, 256)
        self.down3 = Down(256, 512)
        factor = 2 if bilinear else 1
        self.down4 = Down(512, 1024 // factor)
        self.up1 = Up(1024, 512 // factor, bilinear)
        self.up2 = Up(512, 256 // factor, bilinear)
        self.up3 = Up(256, 128 // factor, bilinear)
        self.up4 = Up(128, 64, bilinear)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        logits = self.outc(x)
        return logits
    def use_checkpointing(self):
        # torch.utils.checkpoint is a module, not a callable, so each block is
        # wrapped in a small nn.Module that runs its forward pass under
        # activation checkpointing (activations are recomputed during backward,
        # trading compute for memory). Note that wrapping prefixes the wrapped
        # blocks' state_dict keys with '.module'.
        from torch.utils.checkpoint import checkpoint

        class _Checkpointed(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, *inputs):
                return checkpoint(self.module, *inputs, use_reentrant=False)

        for name in ('inc', 'down1', 'down2', 'down3', 'down4',
                     'up1', 'up2', 'up3', 'up4', 'outc'):
            setattr(self, name, _Checkpointed(getattr(self, name)))
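
if __name__ == '__main__':
    # Quick sanity check (illustrative; run as `python -m unet.unet_model`
    # from the repo root, since this module uses a relative import):
    # the output logits keep the input's spatial size.
    net = UNet(n_channels=3, n_classes=2)
    x = torch.randn(1, 3, 64, 64)
    print(net(x).shape)  # -> torch.Size([1, 2, 64, 64])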
""" Parts of the U-Net model """
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
"""(convolution => [BN] => ReLU) * 2"""
def __init__(self, in_channels, out_channels, mid_channels=None):
super().__init__()
if not mid_channels:
mid_channels = out_channels
self.double_conv = nn.Sequential(
nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(mid_channels),
nn.ReLU(inplace=True),
nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)
)
def forward(self, x):
return self.double_conv(x)
class Down(nn.Module):
"""Downscaling with maxpool then double conv"""
def __init__(self, in_channels, out_channels):
super().__init__()
self.maxpool_conv = nn.Sequential(
nn.MaxPool2d(2),
DoubleConv(in_channels, out_channels)
)
def forward(self, x):
return self.maxpool_conv(x)
class Up(nn.Module):
"""Upscaling then double conv"""
def __init__(self, in_channels, out_channels, bilinear=True):
super().__init__()
# if bilinear, use the normal convolutions to reduce the number of channels
if bilinear:
self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
else:
self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
self.conv = DoubleConv(in_channels, out_channels)
def forward(self, x1, x2):
x1 = self.up(x1)
# input is CHW
diffY = x2.size()[2] - x1.size()[2]
diffX = x2.size()[3] - x1.size()[3]
x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
diffY // 2, diffY - diffY // 2])
# if you have padding issues, see
# https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
# https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
x = torch.cat([x2, x1], dim=1)
return self.conv(x)
class OutConv(nn.Module):
def __init__(self, in_channels, out_channels):
super(OutConv, self).__init__()
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
def forward(self, x):
return self.conv(x)
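
if __name__ == '__main__':
    # Illustrative check of the padding logic in Up.forward: in bilinear mode
    # x1 and x2 each carry in_channels // 2 = 4 channels, and the skip tensor
    # x2 is one pixel larger than 2 * 10, so F.pad grows the upsampled x1 from
    # 20x20 to 21x21 before the channel-wise concatenation.
    up = Up(in_channels=8, out_channels=4, bilinear=True)
    x1 = torch.randn(1, 4, 10, 10)  # low-resolution decoder features
    x2 = torch.randn(1, 4, 21, 21)  # encoder skip features with odd size
    print(up(x1, x2).shape)         # -> torch.Size([1, 4, 21, 21])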
import logging
import numpy as np
import torch
from PIL import Image
from functools import lru_cache
from functools import partial
from itertools import repeat
from multiprocessing import Pool
from os import listdir
from os.path import splitext, isfile, join
from pathlib import Path
from torch.utils.data import Dataset
from tqdm import tqdm


def load_image(filename):
    ext = splitext(filename)[1]
    if ext == '.npy':
        return Image.fromarray(np.load(filename))
    elif ext in ['.pt', '.pth']:
        return Image.fromarray(torch.load(filename).numpy())
    else:
        return Image.open(filename)


def unique_mask_values(idx, mask_dir, mask_suffix):
    mask_file = list(mask_dir.glob(idx + mask_suffix + '.*'))[0]
    mask = np.asarray(load_image(mask_file))
    if mask.ndim == 2:
        return np.unique(mask)
    elif mask.ndim == 3:
        mask = mask.reshape(-1, mask.shape[-1])
        return np.unique(mask, axis=0)
    else:
        raise ValueError(f'Loaded masks should have 2 or 3 dimensions, found {mask.ndim}')


class BasicDataset(Dataset):
    def __init__(self, images_dir: str, mask_dir: str, scale: float = 1.0, mask_suffix: str = ''):
        self.images_dir = Path(images_dir)
        self.mask_dir = Path(mask_dir)
        assert 0 < scale <= 1, 'Scale must be between 0 and 1'
        self.scale = scale
        self.mask_suffix = mask_suffix

        self.ids = [splitext(file)[0] for file in listdir(images_dir) if isfile(join(images_dir, file)) and not file.startswith('.')]
        if not self.ids:
            raise RuntimeError(f'No input file found in {images_dir}, make sure you put your images there')

        logging.info(f'Creating dataset with {len(self.ids)} examples')
        logging.info('Scanning mask files to determine unique values')
        with Pool() as p:
            unique = list(tqdm(
                p.imap(partial(unique_mask_values, mask_dir=self.mask_dir, mask_suffix=self.mask_suffix), self.ids),
                total=len(self.ids)
            ))

        self.mask_values = list(sorted(np.unique(np.concatenate(unique), axis=0).tolist()))
        logging.info(f'Unique mask values: {self.mask_values}')

    def __len__(self):
        return len(self.ids)

    @staticmethod
    def preprocess(mask_values, pil_img, scale, is_mask):
        w, h = pil_img.size
        newW, newH = int(scale * w), int(scale * h)
        assert newW > 0 and newH > 0, 'Scale is too small, resized images would have no pixel'
        pil_img = pil_img.resize((newW, newH), resample=Image.NEAREST if is_mask else Image.BICUBIC)
        img = np.asarray(pil_img)

        if is_mask:
            mask = np.zeros((newH, newW), dtype=np.int64)
            for i, v in enumerate(mask_values):
                if img.ndim == 2:
                    mask[img == v] = i
                else:
                    mask[(img == v).all(-1)] = i

            return mask
        else:
            if img.ndim == 2:
                img = img[np.newaxis, ...]
            else:
                img = img.transpose((2, 0, 1))

            if (img > 1).any():
                img = img / 255.0

            return img

    def __getitem__(self, idx):
        name = self.ids[idx]
        mask_file = list(self.mask_dir.glob(name + self.mask_suffix + '.*'))
        img_file = list(self.images_dir.glob(name + '.*'))

        assert len(img_file) == 1, f'Either no image or multiple images found for the ID {name}: {img_file}'
        assert len(mask_file) == 1, f'Either no mask or multiple masks found for the ID {name}: {mask_file}'
        mask = load_image(mask_file[0])
        img = load_image(img_file[0])

        assert img.size == mask.size, \
            f'Image and mask {name} should be the same size, but are {img.size} and {mask.size}'

        img = self.preprocess(self.mask_values, img, self.scale, is_mask=False)
        mask = self.preprocess(self.mask_values, mask, self.scale, is_mask=True)

        return {
            'image': torch.as_tensor(img.copy()).float().contiguous(),
            'mask': torch.as_tensor(mask.copy()).long().contiguous()
        }


class CarvanaDataset(BasicDataset):
    def __init__(self, images_dir, mask_dir, scale=1):
        super().__init__(images_dir, mask_dir, scale, mask_suffix='_mask')
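
if __name__ == '__main__':
    # Illustrative usage (assumes the data/ layout from the README; the
    # bundled mini dataset ships masks without the Carvana '_mask' suffix,
    # so BasicDataset is used here rather than CarvanaDataset).
    logging.basicConfig(level=logging.INFO)
    ds = BasicDataset('data/imgs', 'data/masks', scale=0.5)
    sample = ds[0]
    print(sample['image'].shape, sample['mask'].shape)  # [3, H, W] and [H, W]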
import torch
from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)


def iou_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of IoU for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    intersection = (input * target).sum(dim=sum_dim)
    union = (input + target).sum(dim=sum_dim) - intersection
    union = torch.where(union == 0, intersection, union)

    iou = (intersection + epsilon) / (union + epsilon)
    return iou.mean()


def multiclass_iou_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of IoU for all classes
    return iou_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def iou_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # IoU loss (objective to minimize) between 0 and 1
    fn = multiclass_iou_coeff if multiclass else iou_coeff
    return 1 - fn(input, target, reduce_batch_first=True)
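
if __name__ == '__main__':
    # Worked example: two 2x2 binary masks that overlap in one pixel.
    # intersection = 1 and union = 2 + 2 - 1 = 3, so IoU ≈ 1/3, while
    # Dice = 2*1 / (2 + 2) = 0.5.
    pred = torch.tensor([[1., 1.], [0., 0.]])
    true = torch.tensor([[1., 0.], [1., 0.]])
    print(float(iou_coeff(pred, true)))   # ≈ 0.3333
    print(float(dice_coeff(pred, true)))  # = 0.5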