"...community/stable_diffusion_controlnet_inpaint_img2img.py" did not exist on "07c0fe4b87a07fc1b42bac738f013b78833559ae"
Commit 1e2486af authored by sunxx1's avatar sunxx1
Browse files

Add inception_v3 test code

parents
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from .dataset import PlainDataset
import os
def get_abs_path(rel):
    """Resolve *rel* against the current working directory and return the joined path."""
    base = os.getcwd()
    return os.path.join(base, rel)
def build_augmentation(cfg):
    """Assemble a torchvision transform pipeline from the config flags.

    Each augmentation is appended only when its config entry is truthy; the
    pipeline always ends with ToTensor followed by Normalize (ImageNet
    statistics unless ``mean``/``std`` are present in *cfg*).
    """
    ops = []
    if cfg.random_resize_crop:
        ops.append(transforms.RandomResizedCrop(cfg.random_resize_crop))
    if cfg.resize:
        ops.append(transforms.Resize(cfg.resize))
    if cfg.random_crop:
        ops.append(transforms.RandomCrop(cfg.random_crop))
    if cfg.center_crop:
        ops.append(transforms.CenterCrop(cfg.center_crop))
    if cfg.mirror:
        ops.append(transforms.RandomHorizontalFlip())
    if cfg.colorjitter:
        ops.append(transforms.ColorJitter(*cfg.colorjitter))
    mean = cfg.get('mean', [0.485, 0.456, 0.406])
    std = cfg.get('std', [0.229, 0.224, 0.225])
    ops.append(transforms.ToTensor())
    ops.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(ops)
def build_dataloader(cfg, world_size):
    """Build distributed train and test dataloaders from *cfg*.

    Returns ``(train_loader, train_sampler, test_loader, test_sampler)``.
    Samplers are returned so callers can call ``set_epoch`` each epoch.
    """
    def _make(split_cfg):
        # One split = augmentation + dataset + distributed sampler + loader.
        aug = build_augmentation(split_cfg)
        dataset = PlainDataset(split_cfg.image_dir, split_cfg.meta_file, aug)
        sampler = DistributedSampler(dataset)
        loader = DataLoader(dataset,
                            batch_size=cfg.batch_size,
                            # shuffling is delegated to the sampler when present
                            shuffle=(sampler is None),
                            num_workers=cfg.workers,
                            pin_memory=True,
                            sampler=sampler,
                            drop_last=False)
        return loader, sampler

    train_loader, train_sampler = _make(cfg.train)
    test_loader, test_sampler = _make(cfg.test)
    return train_loader, train_sampler, test_loader, test_sampler
from PIL import Image
from torch.utils.data import Dataset
class PlainDataset(Dataset):
    r"""Image/label dataset described by a plain-text meta file.

    Arguments
        * root (string): Root directory of the Dataset.
        * meta_file (string): The meta file of the Dataset. Each line has a image path
          and a label. Eg: ``nm091234/image_56.jpg 18``.
        * transform (callable, optional): A function that transforms the given PIL image
          and returns a transformed image.
    """

    def __init__(self, root, meta_file, transform=None):
        self.root = root
        self.transform = transform
        self.metas = []
        with open(meta_file) as f:
            for line in f:
                # skip blank lines instead of crashing on the unpack below
                if not line.strip():
                    continue
                path, cls = line.strip().split()
                self.metas.append((path, int(cls)))
        # kept as an attribute for backward compatibility with external readers
        self.num = len(self.metas)

    def __len__(self):
        return self.num

    def __getitem__(self, index):
        path, cls = self.metas[index]
        # os.path.join handles separators portably (was: root + '/' + path)
        filename = os.path.join(self.root, path)
        with Image.open(filename) as img:
            img = img.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, cls
import torch
from torch.nn.modules.loss import _Loss
import torch.nn.functional as F
class LabelSmoothLoss(_Loss):
    """Cross-entropy loss with uniform label smoothing.

    Each off-target class receives probability mass ``smooth_ratio / num_classes``
    and the target class receives the remainder.
    """

    def __init__(self, smooth_ratio, num_classes):
        super(LabelSmoothLoss, self).__init__()
        self.smooth_ratio = smooth_ratio
        # per-class smoothing mass
        self.v = self.smooth_ratio / num_classes

    def forward(self, input, label):
        # Start from the uniform smoothing mass, then place the remaining
        # probability on the target class of each row.
        smoothed = torch.full_like(input, self.v)
        target = label.to(torch.long).view(-1, 1)
        smoothed.scatter_(1, target, 1 - self.smooth_ratio + self.v)
        log_probs = F.log_softmax(input, 1)
        # mean over the batch of the smoothed negative log-likelihood
        return -(log_probs * smoothed.detach()).sum() / input.size(0)
import numpy as np
import torch
import logging
logger = logging.getLogger()
def check_keys(model, checkpoint):
    """Compare model and checkpoint state-dict keys and return the shared set.

    Logs a warning for keys present on only one side and an info line for
    every key the two have in common.
    """
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(checkpoint['state_dict'].keys())
    for key in model_keys - ckpt_keys:
        logger.warning('missing key in model:{}'.format(key))
    for key in ckpt_keys - model_keys:
        logger.warning('unexpected key in checkpoint:{}'.format(key))
    common = model_keys & ckpt_keys
    for key in common:
        logger.info('shared key:{}'.format(key))
    return common
def accuracy(output, target, topk=(1, ), raw=False):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        # (maxk, batch) matrix of top predictions, rows ordered best-first
        top_pred = output.topk(max(topk), 1, True, True)[1].t()
        matches = top_pred.eq(target.view(1, -1).expand_as(top_pred))
        results = []
        for k in topk:
            # a sample counts as correct if the target is in its first k rows
            hit_k = matches[:k].contiguous().view(-1).float().sum(0, keepdim=True)
            results.append(hit_k if raw else hit_k.mul(100.0 / target.size(0)))
        return results
class AverageMeter(object):
    """Tracks a scalar statistic.

    * length > 1  : sliding-window mean over the last ``length`` updates.
    * length < 0  : cumulative mean over every update since the last reset.
    * length 0/1  : only the most recent value is kept (avg stays 0).
    """

    def __init__(self, name, fmt=':f', length=1):
        self.name = name
        self.fmt = fmt
        self.length = length
        self.reset()

    def reset(self):
        """Clear all accumulated state for the configured mode."""
        if self.length > 1:
            self.history = []
        elif self.length < 0:
            self.count = 0
            self.sum = 0
        self.avg = 0
        self.val = 0

    def update(self, val):
        """Record a new value and refresh the running average."""
        self.val = val
        if self.length > 1:
            self.history.append(val)
            # keep only the newest ``length`` samples in the window
            if len(self.history) > self.length:
                self.history.pop(0)
            self.avg = np.mean(self.history)
        elif self.length < 0:
            self.sum += val
            self.count += 1
            self.avg = self.sum / self.count

    def __str__(self):
        if self.length in (0, 1):
            template = '{name} {val' + self.fmt + '}'
        else:
            template = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return template.format(**self.__dict__)
class ProgressMeter(object):
    """Formats and logs a ``[batch/total] meter1 meter2 ...`` progress line."""

    def __init__(self, num_batches, *meters, prefix=''):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Log the current batch index and every meter via the module logger."""
        entries = [self.prefix + self.batch_fmtstr.format(batch)]
        entries += [str(meter) for meter in self.meters]
        logger.info(' '.join(entries))

    def _get_batch_fmtstr(self, num_batches):
        # Field width = digit count of the total (was ``num_batches // 1``,
        # a no-op integer division).
        num_digits = len(str(num_batches))
        fmt = '{:' + str(num_digits) + 'd}'
        return '[' + fmt + '/' + fmt.format(num_batches) + ']'
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment