Commit 7339f0b0 authored by chenzk

v1.0
import os
import cv2
import numpy as np
import torch
import torch.utils.data
class Dataset(torch.utils.data.Dataset):
def __init__(self, img_ids, img_dir, mask_dir, img_ext, mask_ext, num_classes, transform=None):
"""
Args:
img_ids (list): Image ids.
img_dir: Image file directory.
mask_dir: Mask file directory.
img_ext (str): Image file extension.
mask_ext (str): Mask file extension.
num_classes (int): Number of classes.
transform (Compose, optional): Compose transforms of albumentations. Defaults to None.
Note:
Make sure to put the files as the following structure:
<dataset name>
├── images
│   ├── 0a7e06.jpg
│   ├── 0aab0a.jpg
│   ├── 0b1761.jpg
│   ├── ...
│
└── masks
    ├── 0
    │   ├── 0a7e06.png
    │   ├── 0aab0a.png
    │   ├── 0b1761.png
    │   ├── ...
    │
    ├── 1
    │   ├── 0a7e06.png
    │   ├── 0aab0a.png
    │   ├── 0b1761.png
    │   ├── ...
    ...
"""
self.img_ids = img_ids
self.img_dir = img_dir
self.mask_dir = mask_dir
self.img_ext = img_ext
self.mask_ext = mask_ext
self.num_classes = num_classes
self.transform = transform
def __len__(self):
return len(self.img_ids)
def __getitem__(self, idx):
img_id = self.img_ids[idx]
img = cv2.imread(os.path.join(self.img_dir, img_id + self.img_ext))
mask = []
for i in range(self.num_classes):
# print(os.path.join(self.mask_dir, str(i),
# img_id + self.mask_ext))
mask.append(cv2.imread(os.path.join(self.mask_dir, str(i),
img_id + self.mask_ext), cv2.IMREAD_GRAYSCALE)[..., None])
mask = np.dstack(mask)
if self.transform is not None:
augmented = self.transform(image=img, mask=mask)
img = augmented['image']
mask = augmented['mask']
img = img.astype('float32') / 255
img = img.transpose(2, 0, 1)
mask = mask.astype('float32') / 255
mask = mask.transpose(2, 0, 1)
# binarize: if no pixel reached 1.0 after the /255 scaling above, treat any
# non-zero value as foreground
if mask.max() < 1:
    mask[mask > 0] = 1.0
return img, mask, {'img_id': img_id}
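For reference, a minimal usage sketch of this class; the ids, paths, and transform below are illustrative placeholders, assuming the directory layout from the docstring:

# Hypothetical example -- not part of the repo; paths and ids are placeholders.
from albumentations import Resize
from albumentations.core.composition import Compose

ds = Dataset(
    img_ids=['0a7e06', '0aab0a'],      # file stems under images/
    img_dir='inputs/busi/images',
    mask_dir='inputs/busi/masks',
    img_ext='.png',
    mask_ext='_mask.png',
    num_classes=1,
    transform=Compose([Resize(256, 256)]),
)
img, mask, meta = ds[0]                # img: (3, 256, 256), mask: (1, 256, 256)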
name: ukan
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=4.5=1_gnu
- ca-certificates=2021.10.26=h06a4308_2
- certifi=2021.5.30=py36h06a4308_0
- ld_impl_linux-64=2.35.1=h7274673_9
- libffi=3.3=he6710b0_2
- libgcc-ng=9.3.0=h5101ec6_17
- libgomp=9.3.0=h5101ec6_17
- libstdcxx-ng=9.3.0=hd4cf53a_17
- ncurses=6.3=h7f8727e_2
- openssl=1.1.1l=h7f8727e_0
- pip=21.2.2=py36h06a4308_0
- python=3.6.13=h12debd9_1
- readline=8.1=h27cfd23_0
- setuptools=58.0.4=py36h06a4308_0
- sqlite=3.36.0=hc218d9a_0
- tk=8.6.11=h1ccaba5_0
- wheel=0.37.0=pyhd3eb1b0_1
- xz=5.2.5=h7b6447c_0
- zlib=1.2.11=h7b6447c_3
- pip:
- addict==2.4.0
- dataclasses==0.8
- mmcv-full==1.2.7
- numpy==1.19.5
- opencv-python==4.5.1.48
- perceptual==0.1
- pillow==8.4.0
- scikit-image==0.17.2
- scipy==1.5.4
- tifffile==2020.9.3
- timm==0.3.2
- torch==1.7.1
- torchvision==0.8.2
- typing-extensions==4.0.0
- yapf==0.31.0
# prefix: /home/jeyamariajose/anaconda3/envs/transweather
import torch
import torch.nn.functional as F
import math
class KANLinear(torch.nn.Module):
def __init__(
self,
in_features,
out_features,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
enable_standalone_scale_spline=True,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KANLinear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.grid_size = grid_size
self.spline_order = spline_order
h = (grid_range[1] - grid_range[0]) / grid_size
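# build the knot vector: grid_size uniform intervals over grid_range, padded
# with spline_order extra knots on each side so that degree-`spline_order`
# B-splines are well-defined at the boundaries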
grid = (
(
torch.arange(-spline_order, grid_size + spline_order + 1) * h
+ grid_range[0]
)
.expand(in_features, -1)
.contiguous()
)
self.register_buffer("grid", grid)
self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
self.spline_weight = torch.nn.Parameter(
torch.Tensor(out_features, in_features, grid_size + spline_order)
)
if enable_standalone_scale_spline:
self.spline_scaler = torch.nn.Parameter(
torch.Tensor(out_features, in_features)
)
self.scale_noise = scale_noise
self.scale_base = scale_base
self.scale_spline = scale_spline
self.enable_standalone_scale_spline = enable_standalone_scale_spline
self.base_activation = base_activation()
self.grid_eps = grid_eps
self.reset_parameters()
def reset_parameters(self):
torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
with torch.no_grad():
noise = (
(
torch.rand(self.grid_size + 1, self.in_features, self.out_features)
- 1 / 2
)
* self.scale_noise
/ self.grid_size
)
self.spline_weight.data.copy_(
(self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
* self.curve2coeff(
self.grid.T[self.spline_order : -self.spline_order],
noise,
)
)
if self.enable_standalone_scale_spline:
# torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
def b_splines(self, x: torch.Tensor):
"""
Compute the B-spline bases for the given input tensor.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
grid: torch.Tensor = (
self.grid
) # (in_features, grid_size + 2 * spline_order + 1)
x = x.unsqueeze(-1)
bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
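# Cox-de Boor recursion: each degree-k basis is a blend of two adjacent
# degree-(k-1) bases,
#   B_{i,k}(x) = (x - t_i) / (t_{i+k} - t_i) * B_{i,k-1}(x)
#             + (t_{i+k+1} - x) / (t_{i+k+1} - t_{i+1}) * B_{i+1,k-1}(x)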
for k in range(1, self.spline_order + 1):
bases = (
(x - grid[:, : -(k + 1)])
/ (grid[:, k:-1] - grid[:, : -(k + 1)])
* bases[:, :, :-1]
) + (
(grid[:, k + 1 :] - x)
/ (grid[:, k + 1 :] - grid[:, 1:(-k)])
* bases[:, :, 1:]
)
assert bases.size() == (
x.size(0),
self.in_features,
self.grid_size + self.spline_order,
)
return bases.contiguous()
def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
"""
Compute the coefficients of the curve that interpolates the given points.
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
Returns:
torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
"""
assert x.dim() == 2 and x.size(1) == self.in_features
assert y.size() == (x.size(0), self.in_features, self.out_features)
A = self.b_splines(x).transpose(
0, 1
) # (in_features, batch_size, grid_size + spline_order)
B = y.transpose(0, 1) # (in_features, batch_size, out_features)
solution = torch.linalg.lstsq(
A, B
).solution # (in_features, grid_size + spline_order, out_features)
result = solution.permute(
2, 0, 1
) # (out_features, in_features, grid_size + spline_order)
assert result.size() == (
self.out_features,
self.in_features,
self.grid_size + self.spline_order,
)
return result.contiguous()
@property
def scaled_spline_weight(self):
return self.spline_weight * (
self.spline_scaler.unsqueeze(-1)
if self.enable_standalone_scale_spline
else 1.0
)
def forward(self, x: torch.Tensor):
assert x.dim() == 2 and x.size(1) == self.in_features
base_output = F.linear(self.base_activation(x), self.base_weight)
spline_output = F.linear(
self.b_splines(x).view(x.size(0), -1),
self.scaled_spline_weight.view(self.out_features, -1),
)
return base_output + spline_output
@torch.no_grad()
def update_grid(self, x: torch.Tensor, margin=0.01):
assert x.dim() == 2 and x.size(1) == self.in_features
batch = x.size(0)
splines = self.b_splines(x) # (batch, in, coeff)
splines = splines.permute(1, 0, 2) # (in, batch, coeff)
orig_coeff = self.scaled_spline_weight # (out, in, coeff)
orig_coeff = orig_coeff.permute(1, 2, 0) # (in, coeff, out)
unreduced_spline_output = torch.bmm(splines, orig_coeff) # (in, batch, out)
unreduced_spline_output = unreduced_spline_output.permute(
1, 0, 2
) # (batch, in, out)
# sort each channel individually to collect data distribution
x_sorted = torch.sort(x, dim=0)[0]
grid_adaptive = x_sorted[
torch.linspace(
0, batch - 1, self.grid_size + 1, dtype=torch.int64, device=x.device
)
]
uniform_step = (x_sorted[-1] - x_sorted[0] + 2 * margin) / self.grid_size
grid_uniform = (
torch.arange(
self.grid_size + 1, dtype=torch.float32, device=x.device
).unsqueeze(1)
* uniform_step
+ x_sorted[0]
- margin
)
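# blend the uniform grid with the data-adaptive grid: grid_eps=1 gives a fully
# uniform grid, grid_eps=0 follows the observed input distribution exactly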
grid = self.grid_eps * grid_uniform + (1 - self.grid_eps) * grid_adaptive
grid = torch.cat(  # torch.cat instead of torch.concatenate: also works on older torch
[
grid[:1]
- uniform_step
* torch.arange(self.spline_order, 0, -1, device=x.device).unsqueeze(1),
grid,
grid[-1:]
+ uniform_step
* torch.arange(1, self.spline_order + 1, device=x.device).unsqueeze(1),
],
dim=0,
)
self.grid.copy_(grid.T)
self.spline_weight.data.copy_(self.curve2coeff(x, unreduced_spline_output))
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
"""
Compute the regularization loss.
This is a dumb simulation of the original L1 regularization as stated in the
paper, since the original one requires computing absolutes and entropy from the
expanded (batch, in_features, out_features) intermediate tensor, which is hidden
behind the F.linear function if we want an memory efficient implementation.
The L1 regularization is now computed as mean absolute value of the spline
weights. The authors implementation also includes this term in addition to the
sample-based regularization.
"""
l1_fake = self.spline_weight.abs().mean(-1)
regularization_loss_activation = l1_fake.sum()
p = l1_fake / regularization_loss_activation
regularization_loss_entropy = -torch.sum(p * p.log())
return (
regularize_activation * regularization_loss_activation
+ regularize_entropy * regularization_loss_entropy
)
class KAN(torch.nn.Module):
def __init__(
self,
layers_hidden,
grid_size=5,
spline_order=3,
scale_noise=0.1,
scale_base=1.0,
scale_spline=1.0,
base_activation=torch.nn.SiLU,
grid_eps=0.02,
grid_range=[-1, 1],
):
super(KAN, self).__init__()
self.grid_size = grid_size
self.spline_order = spline_order
self.layers = torch.nn.ModuleList()
for in_features, out_features in zip(layers_hidden, layers_hidden[1:]):
self.layers.append(
KANLinear(
in_features,
out_features,
grid_size=grid_size,
spline_order=spline_order,
scale_noise=scale_noise,
scale_base=scale_base,
scale_spline=scale_spline,
base_activation=base_activation,
grid_eps=grid_eps,
grid_range=grid_range,
)
)
def forward(self, x: torch.Tensor, update_grid=False):
for layer in self.layers:
if update_grid:
layer.update_grid(x)
x = layer(x)
return x
def regularization_loss(self, regularize_activation=1.0, regularize_entropy=1.0):
return sum(
layer.regularization_loss(regularize_activation, regularize_entropy)
for layer in self.layers
)
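A short smoke test of KANLinear/KAN as defined above; the layer widths, batch size, and loss coefficients are arbitrary illustrations, not values used elsewhere in the repo:

# Illustrative smoke test -- sizes are arbitrary.
net = KAN([2, 8, 1], grid_size=5, spline_order=3)
x = torch.rand(64, 2) * 2 - 1              # samples inside the default grid_range [-1, 1]
y = net(x)                                 # (64, 1)
y = net(x, update_grid=True)               # re-fit each layer's grid to this batch first
reg = net.regularization_loss(1.0, 1.0)    # scalar; can be added to the task loss
print(y.shape, reg.item())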
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
    from LovaszSoftmax.pytorch.lovasz_losses import lovasz_hinge
except ImportError:
    pass  # optional: LovaszHingeLoss needs the LovaszSoftmax repo on PYTHONPATH
__all__ = ['BCEDiceLoss', 'LovaszHingeLoss']
class BCEDiceLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, target):
bce = F.binary_cross_entropy_with_logits(input, target)
smooth = 1e-5
input = torch.sigmoid(input)
num = target.size(0)
input = input.view(num, -1)
target = target.view(num, -1)
intersection = (input * target)
dice = (2. * intersection.sum(1) + smooth) / (input.sum(1) + target.sum(1) + smooth)
dice = 1 - dice.sum() / num
return 0.5 * bce + dice
class LovaszHingeLoss(nn.Module):
def __init__(self):
super().__init__()
def forward(self, input, target):
input = input.squeeze(1)
target = target.squeeze(1)
loss = lovasz_hinge(input, target, per_image=True)
return loss
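A quick sanity check of BCEDiceLoss on random stand-in data; shapes mirror the (N, C, H, W) masks produced by Dataset, and the values are placeholders:

# Illustrative only: random logits/targets stand in for real model output.
criterion = BCEDiceLoss()
logits = torch.randn(4, 1, 256, 256)                 # raw model output, no sigmoid
target = (torch.rand(4, 1, 256, 256) > 0.5).float()
loss = criterion(logits, target)                     # 0.5 * BCE + (1 - Dice)
print(loss.item())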
import numpy as np
import torch
import torch.nn.functional as F
from medpy.metric.binary import jc, dc, hd, hd95, recall, specificity, precision
def iou_score(output, target):
smooth = 1e-5
if torch.is_tensor(output):
output = torch.sigmoid(output).data.cpu().numpy()
if torch.is_tensor(target):
target = target.data.cpu().numpy()
output_ = output > 0.5
target_ = target > 0.5
intersection = (output_ & target_).sum()
union = (output_ | target_).sum()
iou = (intersection + smooth) / (union + smooth)
dice = (2 * iou) / (iou + 1)
try:
    hd95_ = hd95(output_, target_)
except Exception:
    # hd95 is undefined when either mask is empty
    hd95_ = 0
return iou, dice, hd95_
def dice_coef(output, target):
smooth = 1e-5
output = torch.sigmoid(output).view(-1).data.cpu().numpy()
target = target.view(-1).data.cpu().numpy()
intersection = (output * target).sum()
return (2. * intersection + smooth) / \
(output.sum() + target.sum() + smooth)
def indicators(output, target):
if torch.is_tensor(output):
output = torch.sigmoid(output).data.cpu().numpy()
if torch.is_tensor(target):
target = target.data.cpu().numpy()
output_ = output > 0.5
target_ = target > 0.5
iou_ = jc(output_, target_)
dice_ = dc(output_, target_)
hd_ = hd(output_, target_)
hd95_ = hd95(output_, target_)
recall_ = recall(output_, target_)
specificity_ = specificity(output_, target_)
precision_ = precision(output_, target_)
return iou_, dice_, hd_, hd95_, recall_, specificity_, precision_
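Example call for iou_score with random stand-in tensors; it accepts logits as tensors and applies sigmoid internally, and hd95 needs medpy plus non-empty masks, which is why the function guards it with try/except:

# Illustrative only: random tensors stand in for model output and labels.
output = torch.randn(4, 1, 64, 64)                   # logits
target = (torch.rand(4, 1, 64, 64) > 0.5).float()
iou, dice, hd95_val = iou_score(output, target)
print('IoU %.4f, Dice %.4f, HD95 %.4f' % (iou, dice, hd95_val))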
# Model code
modelCode=697
# Model name
modelName=u-kan-optimize_pytorch
# Model description
modelDescription=U-KAN achieves higher accuracy than UNet models built on Mamba or Transformer; this algorithm further optimizes accuracy on top of U-KAN.
# Application scenarios
appScenario=training, image segmentation, healthcare, e-commerce, manufacturing, energy
# Framework type
frameType=pytorch
addict==2.4.0
# dataclasses
# pandas
pyyaml
albumentations
tqdm
tensorboardX
# mmcv-full==1.2.7
# numpy
opencv-python
perceptual==0.1
pillow==8.4.0
# pyproject.toml  (not a pip package; likely committed here by mistake)
# scikit-image==0.17.2
# scipy==1.5.4
tifffile==2020.9.3
timm==0.3.2
# typing-extensions==4.0.0
yapf==0.31.0
medpy
# pip install torch==1.13.0+cu116 torchvision==0.14.0+cu116 torchaudio==0.13.0 --extra-index-url https://download.pytorch.org/whl/cu116
# if torch>=2.0, edit /usr/local/lib/python3.10/site-packages/timm/models/layers/helpers.py: replace "from torch._six import container_abcs" with "import collections.abc as container_abcs"
dataset=busi
input_size=256
python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR]
python val.py --name ${dataset}_UKAN
dataset=glas
input_size=512
python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR]
python val.py --name ${dataset}_UKAN
dataset=cvc
input_size=256
python train.py --arch UKAN --dataset ${dataset} --input_w ${input_size} --input_h ${input_size} --name ${dataset}_UKAN --data_dir [YOUR_DATA_DIR]
python val.py --name ${dataset}_UKAN
import argparse
import os
from collections import OrderedDict
from glob import glob
import random
import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import yaml
from albumentations.augmentations import transforms
from albumentations.augmentations import geometric
from albumentations.core.composition import Compose, OneOf
from sklearn.model_selection import train_test_split
from torch.optim import lr_scheduler
from tqdm import tqdm
from albumentations import RandomRotate90, Resize
import archs
import losses
from dataset import Dataset
from metrics import iou_score, indicators
from utils import AverageMeter, str2bool
from tensorboardX import SummaryWriter
import shutil
import os
import subprocess
from pdb import set_trace as st
ARCH_NAMES = archs.__all__
LOSS_NAMES = losses.__all__
LOSS_NAMES.append('BCEWithLogitsLoss')
def list_type(s):
str_list = s.split(',')
int_list = [int(a) for a in str_list]
return int_list
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None,
help='model name: (default: arch+timestamp)')
parser.add_argument('--epochs', default=400, type=int, metavar='N',
help='number of total epochs to run')
parser.add_argument('-b', '--batch_size', default=8, type=int,
                    metavar='N', help='mini-batch size (default: 8)')
parser.add_argument('--dataseed', default=2981, type=int,
help='')
# model
parser.add_argument('--arch', '-a', metavar='ARCH', default='UKAN')
parser.add_argument('--deep_supervision', default=False, type=str2bool)
parser.add_argument('--input_channels', default=3, type=int,
help='input channels')
parser.add_argument('--num_classes', default=1, type=int,
help='number of classes')
parser.add_argument('--input_w', default=256, type=int,
help='image width')
parser.add_argument('--input_h', default=256, type=int,
help='image height')
parser.add_argument('--input_list', type=list_type, default=[128, 160, 256])
# loss
parser.add_argument('--loss', default='BCEDiceLoss',
choices=LOSS_NAMES,
help='loss: ' +
' | '.join(LOSS_NAMES) +
' (default: BCEDiceLoss)')
# dataset
parser.add_argument('--dataset', default='busi', help='dataset name')
parser.add_argument('--data_dir', default='inputs', help='dataset dir')
parser.add_argument('--output_dir', default='outputs', help='output dir')
# optimizer
parser.add_argument('--optimizer', default='Adam',
                    choices=['Adam', 'SGD'],
                    help='optimizer: ' +
                    ' | '.join(['Adam', 'SGD']) +
                    ' (default: Adam)')
parser.add_argument('--lr', '--learning_rate', default=1e-4, type=float,
metavar='LR', help='initial learning rate')
parser.add_argument('--momentum', default=0.9, type=float,
help='momentum')
parser.add_argument('--weight_decay', default=1e-4, type=float,
help='weight decay')
parser.add_argument('--nesterov', default=False, type=str2bool,
help='nesterov')
parser.add_argument('--kan_lr', default=1e-2, type=float,
                    metavar='LR', help='initial learning rate for KAN layers')
parser.add_argument('--kan_weight_decay', default=1e-4, type=float,
                    help='weight decay for KAN layers')
# scheduler
parser.add_argument('--scheduler', default='CosineAnnealingLR',
choices=['CosineAnnealingLR', 'ReduceLROnPlateau', 'MultiStepLR', 'ConstantLR'])
parser.add_argument('--min_lr', default=1e-5, type=float,
help='minimum learning rate')
parser.add_argument('--factor', default=0.1, type=float)
parser.add_argument('--patience', default=2, type=int)
parser.add_argument('--milestones', default='1,2', type=str)
parser.add_argument('--gamma', default=2/3, type=float)
parser.add_argument('--early_stopping', default=-1, type=int,
metavar='N', help='early stopping (default: -1)')
parser.add_argument('--cfg', type=str, metavar="FILE", help='path to config file', )
parser.add_argument('--num_workers', default=4, type=int)
parser.add_argument('--no_kan', action='store_true')
config = parser.parse_args()
return config
def train(config, train_loader, model, criterion, optimizer):
avg_meters = {'loss': AverageMeter(),
'iou': AverageMeter()}
model.train()
pbar = tqdm(total=len(train_loader))
for input, target, _ in train_loader:
input = input.cuda()
target = target.cuda()
# compute output
if config['deep_supervision']:
outputs = model(input)
loss = 0
for output in outputs:
loss += criterion(output, target)
loss /= len(outputs)
iou, dice, _ = iou_score(outputs[-1], target)
iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(outputs[-1], target)
else:
output = model(input)
loss = criterion(output, target)
iou, dice, _ = iou_score(output, target)
iou_, dice_, hd_, hd95_, recall_, specificity_, precision_ = indicators(output, target)
# compute gradient and do optimizing step
optimizer.zero_grad()
loss.backward()
optimizer.step()
avg_meters['loss'].update(loss.item(), input.size(0))
avg_meters['iou'].update(iou, input.size(0))
postfix = OrderedDict([
('loss', avg_meters['loss'].avg),
('iou', avg_meters['iou'].avg),
])
pbar.set_postfix(postfix)
pbar.update(1)
pbar.close()
return OrderedDict([('loss', avg_meters['loss'].avg),
('iou', avg_meters['iou'].avg)])
def validate(config, val_loader, model, criterion):
avg_meters = {'loss': AverageMeter(),
'iou': AverageMeter(),
'dice': AverageMeter()}
# switch to evaluate mode
model.eval()
with torch.no_grad():
pbar = tqdm(total=len(val_loader))
for input, target, _ in val_loader:
input = input.cuda()
target = target.cuda()
# compute output
if config['deep_supervision']:
outputs = model(input)
loss = 0
for output in outputs:
loss += criterion(output, target)
loss /= len(outputs)
iou, dice, _ = iou_score(outputs[-1], target)
else:
output = model(input)
loss = criterion(output, target)
iou, dice, _ = iou_score(output, target)
avg_meters['loss'].update(loss.item(), input.size(0))
avg_meters['iou'].update(iou, input.size(0))
avg_meters['dice'].update(dice, input.size(0))
postfix = OrderedDict([
('loss', avg_meters['loss'].avg),
('iou', avg_meters['iou'].avg),
('dice', avg_meters['dice'].avg)
])
pbar.set_postfix(postfix)
pbar.update(1)
pbar.close()
return OrderedDict([('loss', avg_meters['loss'].avg),
('iou', avg_meters['iou'].avg),
('dice', avg_meters['dice'].avg)])
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def main():
    seed_torch()
    config = vars(parse_args())
    if config['name'] is None:
        if config['deep_supervision']:
            config['name'] = '%s_%s_wDS' % (config['dataset'], config['arch'])
        else:
            config['name'] = '%s_%s_woDS' % (config['dataset'], config['arch'])
    # resolve the experiment name before creating the writer and output dir,
    # otherwise a defaulted name would end up under '<output_dir>/None'
    exp_name = config['name']
    output_dir = config['output_dir']
    os.makedirs(f'{output_dir}/{exp_name}', exist_ok=True)
    my_writer = SummaryWriter(f'{output_dir}/{exp_name}')
print('-' * 20)
for key in config:
print('%s: %s' % (key, config[key]))
print('-' * 20)
with open(f'{output_dir}/{exp_name}/config.yml', 'w') as f:
yaml.dump(config, f)
# define loss function (criterion)
if config['loss'] == 'BCEWithLogitsLoss':
criterion = nn.BCEWithLogitsLoss().cuda()
else:
criterion = losses.__dict__[config['loss']]().cuda()
cudnn.benchmark = True
# create model
model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'], no_kan=config['no_kan'])
model = model.cuda()
param_groups = []
kan_fc_params = []
other_params = []
for name, param in model.named_parameters():
# print(name, "=>", param.shape)
if 'layer' in name.lower() and 'fc' in name.lower(): # higher lr for kan layers
# kan_fc_params.append(name)
param_groups.append({'params': param, 'lr': config['kan_lr'], 'weight_decay': config['kan_weight_decay']})
else:
# other_params.append(name)
param_groups.append({'params': param, 'lr': config['lr'], 'weight_decay': config['weight_decay']})
# st()
if config['optimizer'] == 'Adam':
optimizer = optim.Adam(param_groups)
elif config['optimizer'] == 'SGD':
    # pass the same param_groups so KAN layers keep their own lr/weight_decay
    optimizer = optim.SGD(param_groups, lr=config['lr'], momentum=config['momentum'],
                          nesterov=config['nesterov'], weight_decay=config['weight_decay'])
else:
raise NotImplementedError
if config['scheduler'] == 'CosineAnnealingLR':
scheduler = lr_scheduler.CosineAnnealingLR(
optimizer, T_max=config['epochs'], eta_min=config['min_lr'])
elif config['scheduler'] == 'ReduceLROnPlateau':
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, factor=config['factor'], patience=config['patience'], verbose=1, min_lr=config['min_lr'])
elif config['scheduler'] == 'MultiStepLR':
scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=[int(e) for e in config['milestones'].split(',')], gamma=config['gamma'])
elif config['scheduler'] == 'ConstantLR':
scheduler = None
else:
raise NotImplementedError
shutil.copy2('train.py', f'{output_dir}/{exp_name}/')
shutil.copy2('archs.py', f'{output_dir}/{exp_name}/')
dataset_name = config['dataset']
img_ext = '.png'
if dataset_name == 'busi':
    mask_ext = '_mask.png'
elif dataset_name in ('glas', 'cvc'):
    mask_ext = '.png'
else:
    raise ValueError('unknown dataset: %s' % dataset_name)
# Data loading code
img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext)))
img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
train_img_ids, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed'])
train_transform = Compose([
RandomRotate90(),
geometric.transforms.Flip(),
Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
])
val_transform = Compose([
Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
])
train_dataset = Dataset(
img_ids=train_img_ids,
img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'),
mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),
img_ext=img_ext,
mask_ext=mask_ext,
num_classes=config['num_classes'],
transform=train_transform)
val_dataset = Dataset(
img_ids=val_img_ids,
img_dir=os.path.join(config['data_dir'] ,config['dataset'], 'images'),
mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),
img_ext=img_ext,
mask_ext=mask_ext,
num_classes=config['num_classes'],
transform=val_transform)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config['batch_size'],
shuffle=True,
num_workers=config['num_workers'],
drop_last=True)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
num_workers=config['num_workers'],
drop_last=False)
log = OrderedDict([
('epoch', []),
('lr', []),
('loss', []),
('iou', []),
('val_loss', []),
('val_iou', []),
('val_dice', []),
])
best_iou = 0
best_dice = 0
trigger = 0
for epoch in range(config['epochs']):
print('Epoch [%d/%d]' % (epoch, config['epochs']))
# train for one epoch
train_log = train(config, train_loader, model, criterion, optimizer)
# evaluate on validation set
val_log = validate(config, val_loader, model, criterion)
if config['scheduler'] == 'CosineAnnealingLR':
scheduler.step()
elif config['scheduler'] == 'ReduceLROnPlateau':
scheduler.step(val_log['loss'])
print('loss %.4f - iou %.4f - val_loss %.4f - val_iou %.4f'
% (train_log['loss'], train_log['iou'], val_log['loss'], val_log['iou']))
log['epoch'].append(epoch)
log['lr'].append(config['lr'])
log['loss'].append(train_log['loss'])
log['iou'].append(train_log['iou'])
log['val_loss'].append(val_log['loss'])
log['val_iou'].append(val_log['iou'])
log['val_dice'].append(val_log['dice'])
pd.DataFrame(log).to_csv(f'{output_dir}/{exp_name}/log.csv', index=False)
my_writer.add_scalar('train/loss', train_log['loss'], global_step=epoch)
my_writer.add_scalar('train/iou', train_log['iou'], global_step=epoch)
my_writer.add_scalar('val/loss', val_log['loss'], global_step=epoch)
my_writer.add_scalar('val/iou', val_log['iou'], global_step=epoch)
my_writer.add_scalar('val/dice', val_log['dice'], global_step=epoch)
my_writer.add_scalar('val/best_iou_value', best_iou, global_step=epoch)
my_writer.add_scalar('val/best_dice_value', best_dice, global_step=epoch)
trigger += 1
if val_log['iou'] > best_iou:
torch.save(model.state_dict(), f'{output_dir}/{exp_name}/model.pth')
best_iou = val_log['iou']
best_dice = val_log['dice']
print("=> saved best model")
print('IoU: %.4f' % best_iou)
print('Dice: %.4f' % best_dice)
trigger = 0
# early stopping
if config['early_stopping'] >= 0 and trigger >= config['early_stopping']:
print("=> early stopping")
break
torch.cuda.empty_cache()
if __name__ == '__main__':
main()
python train.py --arch UKAN --dataset busi --input_w 256 --input_h 256 --name busi_UKAN --data_dir ./inputs
import argparse
import torch.nn as nn
class qkv_transform(nn.Conv1d):
"""Conv1d for qkv_transform"""
def str2bool(v):
    # accept 'true'/'false' and '1'/'0' (argparse passes strings, never ints)
    if v.lower() in ['true', '1']:
        return True
    elif v.lower() in ['false', '0']:
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def count_params(model):
return sum(p.numel() for p in model.parameters() if p.requires_grad)
class AverageMeter(object):
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
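Usage sketch for the helpers above; the numbers are illustrative:

# AverageMeter keeps a sample-weighted running mean.
meter = AverageMeter()
meter.update(0.8, n=8)     # e.g. batch IoU 0.8 averaged over 8 samples
meter.update(0.6, n=4)
print(meter.avg)           # (0.8*8 + 0.6*4) / 12 = 0.7333...

# str2bool plugs into argparse as a type:
# parser.add_argument('--deep_supervision', type=str2bool, default=False)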
#! /data/cxli/miniconda3/envs/th200/bin/python
import argparse
import os
from glob import glob
import random
import numpy as np
import cv2
import torch
import torch.backends.cudnn as cudnn
import yaml
from albumentations.augmentations import transforms
from albumentations.core.composition import Compose
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from collections import OrderedDict
import archs
from dataset import Dataset
from metrics import iou_score
from utils import AverageMeter
from albumentations import RandomRotate90, Resize
import time
from PIL import Image
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument('--name', default=None, help='model name')
parser.add_argument('--output_dir', default='outputs', help='output dir')
args = parser.parse_args()
return args
def seed_torch(seed=1029):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def main():
seed_torch()
args = parse_args()
with open(f'{args.output_dir}/{args.name}/config.yml', 'r') as f:
config = yaml.load(f, Loader=yaml.FullLoader)
print('-'*20)
for key in config.keys():
print('%s: %s' % (key, str(config[key])))
print('-'*20)
cudnn.benchmark = True
model = archs.__dict__[config['arch']](config['num_classes'], config['input_channels'], config['deep_supervision'], embed_dims=config['input_list'])
model = model.cuda()
dataset_name = config['dataset']
img_ext = '.png'
if dataset_name == 'busi':
    mask_ext = '_mask.png'
elif dataset_name in ('glas', 'cvc'):
    mask_ext = '.png'
else:
    raise ValueError('unknown dataset: %s' % dataset_name)
# Data loading code
img_ids = sorted(glob(os.path.join(config['data_dir'], config['dataset'], 'images', '*' + img_ext)))
# img_ids.sort()
img_ids = [os.path.splitext(os.path.basename(p))[0] for p in img_ids]
_, val_img_ids = train_test_split(img_ids, test_size=0.2, random_state=config['dataseed'])
ckpt = torch.load(f'{args.output_dir}/{args.name}/model.pth')
try:
model.load_state_dict(ckpt)
except RuntimeError:  # key mismatch; report the difference and fall back to a partial load
print("Pretrained model keys:", ckpt.keys())
print("Current model keys:", model.state_dict().keys())
pretrained_dict = {k: v for k, v in ckpt.items() if k in model.state_dict()}
current_dict = model.state_dict()
diff_keys = set(current_dict.keys()) - set(pretrained_dict.keys())
print("Difference in model keys:")
for key in diff_keys:
print(f"Key: {key}")
model.load_state_dict(ckpt, strict=False)
model.eval()
val_transform = Compose([
Resize(config['input_h'], config['input_w']),
transforms.Normalize(),
])
val_dataset = Dataset(
img_ids=val_img_ids,
img_dir=os.path.join(config['data_dir'], config['dataset'], 'images'),
mask_dir=os.path.join(config['data_dir'], config['dataset'], 'masks'),
img_ext=img_ext,
mask_ext=mask_ext,
num_classes=config['num_classes'],
transform=val_transform)
val_loader = torch.utils.data.DataLoader(
val_dataset,
batch_size=config['batch_size'],
shuffle=False,
num_workers=config['num_workers'],
drop_last=False)
iou_avg_meter = AverageMeter()
dice_avg_meter = AverageMeter()
hd95_avg_meter = AverageMeter()
with torch.no_grad():
for input, target, meta in tqdm(val_loader, total=len(val_loader)):
input = input.cuda()
target = target.cuda()
model = model.cuda()
# compute output
output = model(input)
iou, dice, hd95_ = iou_score(output, target)
iou_avg_meter.update(iou, input.size(0))
dice_avg_meter.update(dice, input.size(0))
hd95_avg_meter.update(hd95_, input.size(0))
output = torch.sigmoid(output).cpu().numpy()
output[output >= 0.5] = 1
output[output < 0.5] = 0
os.makedirs(os.path.join(args.output_dir, config['name'], 'out_val'), exist_ok=True)
for pred, img_id in zip(output, meta['img_id']):
pred_np = pred[0].astype(np.uint8)
pred_np = pred_np * 255
img = Image.fromarray(pred_np, 'L')
img.save(os.path.join(args.output_dir, config['name'], 'out_val/{}.jpg'.format(img_id)))
print(config['name'])
print('IoU: %.4f' % iou_avg_meter.avg)
print('Dice: %.4f' % dice_avg_meter.avg)
print('HD95: %.4f' % hd95_avg_meter.avg)
if __name__ == '__main__':
main()
python val.py --name busi_UKAN --output_dir ckpt
FROM image.sourcefind.cn:5000/dcu/admin/base/pytorch:2.1.0-ubuntu20.04-dtk24.04.1-py3.10
ENV DEBIAN_FRONTEND=noninteractive
# RUN yum update && yum install -y git cmake wget build-essential
# RUN source /opt/dtk-24.04.1/env.sh
# Install pip dependencies
COPY requirements.txt requirements.txt
RUN pip3 install -r requirements.txt -i http://mirrors.aliyun.com/pypi/simple/ --trusted-host mirrors.aliyun.com