"""
Train and eval functions used in main.py
"""
import math
import sys
from typing import Iterable, Optional
import torch
from timm.data import Mixup
from timm.utils import accuracy, ModelEma
from losses import DistillationLoss
import utils
def set_bn_state(model):
for m in model.modules():
if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
m.eval()
def train_one_epoch(model: torch.nn.Module, criterion: DistillationLoss,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, loss_scaler,
clip_grad: float = 0,
clip_mode: str = 'norm',
model_ema: Optional[ModelEma] = None, mixup_fn: Optional[Mixup] = None,
set_training_mode=True,
set_bn_eval=False,):
model.train(set_training_mode)
if set_bn_eval:
set_bn_state(model)
metric_logger = utils.MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', utils.SmoothedValue(
window_size=1, fmt='{value:.6f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = 100
for samples, targets in metric_logger.log_every(
data_loader, print_freq, header):
samples = samples.to(device, non_blocking=True)
targets = targets.to(device, non_blocking=True)
if mixup_fn is not None:
samples, targets = mixup_fn(samples, targets)
with torch.cuda.amp.autocast():
outputs = model(samples)
loss = criterion(samples, outputs, targets)
loss_value = loss.item()
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
sys.exit(1)
optimizer.zero_grad()
# this attribute is added by timm on one optimizer (adahessian)
is_second_order = hasattr(
optimizer, 'is_second_order') and optimizer.is_second_order
loss_scaler(loss, optimizer, clip_grad=clip_grad, clip_mode=clip_mode,
parameters=model.parameters(), create_graph=is_second_order)
torch.cuda.synchronize()
if model_ema is not None:
model_ema.update(model)
metric_logger.update(loss=loss_value)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
@torch.no_grad()
def evaluate(data_loader, model, device):
criterion = torch.nn.CrossEntropyLoss()
metric_logger = utils.MetricLogger(delimiter=" ")
header = 'Test:'
# switch to evaluation mode
model.eval()
for images, target in metric_logger.log_every(data_loader, 10, header):
images = images.to(device, non_blocking=True)
target = target.to(device, non_blocking=True)
# compute output
with torch.cuda.amp.autocast():
output = model(images)
loss = criterion(output, target)
acc1, acc5 = accuracy(output, target, topk=(1, 5))
batch_size = images.shape[0]
metric_logger.update(loss=loss.item())
metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print('* Acc@1 {top1.global_avg:.3f} Acc@5 {top5.global_avg:.3f} loss {losses.global_avg:.3f}'
.format(top1=metric_logger.acc1, top5=metric_logger.acc5, losses=metric_logger.loss))
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
python main.py --eval --model repvit_m1_1 --resume pretrain/repvit_m1_1_distill_300e.pth --data-path ~/imagenet
import torch
from timm import create_model
import model
import utils
import torch
import torchvision
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument('--model', default='repvit_m1_1', type=str)
parser.add_argument('--resolution', default=224, type=int)
parser.add_argument('--ckpt', default=None, type=str)
if __name__ == "__main__":
# Load the RepViT model specified by --model and optionally restore a checkpoint
args = parser.parse_args()
model = create_model(args.model, distillation=True)
if args.ckpt:
model.load_state_dict(torch.load(args.ckpt)['model'])
utils.replace_batchnorm(model)
model.eval()
# Trace the model with random data.
resolution = args.resolution
example_input = torch.rand(1, 3, resolution, resolution)
traced_model = torch.jit.trace(model, example_input)
out = traced_model(example_input)
import coremltools as ct
# Convert the traced model to Core ML using the Unified Conversion API,
# treating the input as an image (ct.ImageType).
model = ct.convert(
traced_model,
inputs=[ct.ImageType(shape=example_input.shape)]
)
# Save the converted model.
model.save(f"coreml/{args.model}_{resolution}.mlmodel")
import torch
import time
from timm import create_model
import model
import utils
from fvcore.nn import FlopCountAnalysis
T0 = 5
T1 = 10
for n, batch_size, resolution in [
('repvit_m0_9', 1024, 224),
]:
inputs = torch.randn(1, 3, resolution,
resolution)
model = create_model(n, num_classes=1000)
utils.replace_batchnorm(model)
n_parameters = sum(p.numel()
for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters / 1e6)
flops = FlopCountAnalysis(model, inputs)
print("flops: ", flops.total() / 1e9)
"""
Implements the knowledge distillation loss, proposed in DeiT
"""
import torch
from torch.nn import functional as F
class DistillationLoss(torch.nn.Module):
"""
This module wraps a standard criterion and adds an extra knowledge distillation loss by
taking a teacher model prediction and using it as additional supervision.
"""
def __init__(self, base_criterion: torch.nn.Module, teacher_model: torch.nn.Module,
distillation_type: str, alpha: float, tau: float):
super().__init__()
self.base_criterion = base_criterion
self.teacher_model = teacher_model
assert distillation_type in ['none', 'soft', 'hard']
self.distillation_type = distillation_type
self.alpha = alpha
self.tau = tau
def forward(self, inputs, outputs, labels):
"""
Args:
inputs: The original inputs that are fed to the teacher model
outputs: the outputs of the model to be trained. It is expected to be
either a Tensor, or a Tuple[Tensor, Tensor], with the original output
in the first position and the distillation predictions as the second output
labels: the labels for the base criterion
"""
outputs_kd = None
if not isinstance(outputs, torch.Tensor):
# assume that the model outputs a tuple of [outputs, outputs_kd]
outputs, outputs_kd = outputs
base_loss = self.base_criterion(outputs, labels)
if self.distillation_type == 'none':
return base_loss
if outputs_kd is None:
raise ValueError("When knowledge distillation is enabled, the model is "
"expected to return a Tuple[Tensor, Tensor] with the output of the "
"class_token and the dist_token")
# don't backprop through the teacher
with torch.no_grad():
teacher_outputs = self.teacher_model(inputs)
if self.distillation_type == 'soft':
T = self.tau
# taken from https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
# with slight modifications
distillation_loss = F.kl_div(
F.log_softmax(outputs_kd / T, dim=1),
F.log_softmax(teacher_outputs / T, dim=1),
reduction='sum',
log_target=True
) * (T * T) / outputs_kd.numel()
elif self.distillation_type == 'hard':
distillation_loss = F.cross_entropy(
outputs_kd, teacher_outputs.argmax(dim=1))
loss = base_loss * (1 - self.alpha) + distillation_loss * self.alpha
return loss
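# Usage sketch (illustrative): wire DistillationLoss the same way main.py does,
# here with an untrained regnety_160 teacher and random tensors standing in for
# real batches; the trained model is assumed to return a (class, distill) tuple.
if __name__ == "__main__":
    from timm.models import create_model
    from timm.loss import SoftTargetCrossEntropy
    teacher = create_model('regnety_160', pretrained=False, num_classes=1000).eval()
    criterion = DistillationLoss(SoftTargetCrossEntropy(), teacher,
                                 distillation_type='hard', alpha=0.5, tau=1.0)
    images = torch.randn(2, 3, 224, 224)
    soft_targets = torch.randn(2, 1000).softmax(dim=-1)  # mixup-style soft labels
    outputs = (torch.randn(2, 1000), torch.randn(2, 1000))  # (class head, distill head)
    print(criterion(images, outputs, soft_targets))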
import argparse
import datetime
import numpy as np
import time
import torch
import torch.backends.cudnn as cudnn
import json
import os
from pathlib import Path
from timm.data import Mixup
from timm.models import create_model
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from timm.scheduler import create_scheduler
from timm.optim import create_optimizer
from timm.utils import NativeScaler, get_state_dict, ModelEma
from data.samplers import RASampler
from data.datasets import build_dataset
from data.threeaugment import new_data_aug_generator
from engine import train_one_epoch, evaluate
from losses import DistillationLoss
import model
import utils
def get_args_parser():
parser = argparse.ArgumentParser(
'RepViT training and evaluation script', add_help=False)
parser.add_argument('--batch-size', default=256, type=int)
parser.add_argument('--epochs', default=300, type=int)
# Model parameters
parser.add_argument('--model', default='repvit_m1_1', type=str, metavar='MODEL',
help='Name of model to train')
parser.add_argument('--input-size', default=224,
type=int, help='images input size')
parser.add_argument('--model-ema', action='store_true')
parser.add_argument(
'--no-model-ema', action='store_false', dest='model_ema')
parser.set_defaults(model_ema=True)
parser.add_argument('--model-ema-decay', type=float,
default=0.99996, help='')
parser.add_argument('--model-ema-force-cpu',
action='store_true', default=False, help='')
# Optimizer parameters
parser.add_argument('--opt', default='adamw', type=str, metavar='OPTIMIZER',
help='Optimizer (default: "adamw")')
parser.add_argument('--opt-eps', default=1e-8, type=float, metavar='EPSILON',
help='Optimizer Epsilon (default: 1e-8)')
parser.add_argument('--opt-betas', default=None, type=float, nargs='+', metavar='BETA',
help='Optimizer Betas (default: None, use opt default)')
parser.add_argument('--clip-grad', type=float, default=0.02, metavar='NORM',
help='Clip gradient norm (default: 0.02)')
parser.add_argument('--clip-mode', type=str, default='agc',
help='Gradient clipping mode. One of ("norm", "value", "agc")')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M',
help='SGD momentum (default: 0.9)')
parser.add_argument('--weight-decay', type=float, default=0.025,
help='weight decay (default: 0.025)')
# Learning rate schedule parameters
parser.add_argument('--sched', default='cosine', type=str, metavar='SCHEDULER',
help='LR scheduler (default: "cosine")')
parser.add_argument('--lr', type=float, default=1e-3, metavar='LR',
help='learning rate (default: 1e-3)')
parser.add_argument('--lr-noise', type=float, nargs='+', default=None, metavar='pct, pct',
help='learning rate noise on/off epoch percentages')
parser.add_argument('--lr-noise-pct', type=float, default=0.67, metavar='PERCENT',
help='learning rate noise limit percent (default: 0.67)')
parser.add_argument('--lr-noise-std', type=float, default=1.0, metavar='STDDEV',
help='learning rate noise std-dev (default: 1.0)')
parser.add_argument('--warmup-lr', type=float, default=1e-6, metavar='LR',
help='warmup learning rate (default: 1e-6)')
parser.add_argument('--min-lr', type=float, default=1e-5, metavar='LR',
help='lower lr bound for cyclic schedulers that hit 0 (1e-5)')
parser.add_argument('--decay-epochs', type=float, default=30, metavar='N',
help='epoch interval to decay LR')
parser.add_argument('--warmup-epochs', type=int, default=5, metavar='N',
help='epochs to warmup LR, if scheduler supports')
parser.add_argument('--cooldown-epochs', type=int, default=10, metavar='N',
help='epochs to cooldown LR at min_lr, after cyclic schedule ends')
parser.add_argument('--patience-epochs', type=int, default=10, metavar='N',
help='patience epochs for Plateau LR scheduler (default: 10)')
parser.add_argument('--decay-rate', '--dr', type=float, default=0.1, metavar='RATE',
help='LR decay rate (default: 0.1)')
# Augmentation parameters
parser.add_argument('--ThreeAugment', action='store_true')
parser.add_argument('--color-jitter', type=float, default=0.4, metavar='PCT',
help='Color jitter factor (default: 0.4)')
parser.add_argument('--aa', type=str, default='rand-m9-mstd0.5-inc1', metavar='NAME',
help='Use AutoAugment policy. "v0" or "original" (default: rand-m9-mstd0.5-inc1)')
parser.add_argument('--smoothing', type=float, default=0.1,
help='Label smoothing (default: 0.1)')
parser.add_argument('--train-interpolation', type=str, default='bicubic',
help='Training interpolation (random, bilinear, bicubic default: "bicubic")')
parser.add_argument('--repeated-aug', action='store_true')
parser.add_argument('--no-repeated-aug',
action='store_false', dest='repeated_aug')
parser.set_defaults(repeated_aug=True)
# Random Erase params
parser.add_argument('--reprob', type=float, default=0.25, metavar='PCT',
help='Random erase prob (default: 0.25)')
parser.add_argument('--remode', type=str, default='pixel',
help='Random erase mode (default: "pixel")')
parser.add_argument('--recount', type=int, default=1,
help='Random erase count (default: 1)')
parser.add_argument('--resplit', action='store_true', default=False,
help='Do not random erase first (clean) augmentation split')
# Mixup params
parser.add_argument('--mixup', type=float, default=0.8,
help='mixup alpha, mixup enabled if > 0. (default: 0.8)')
parser.add_argument('--cutmix', type=float, default=1.0,
help='cutmix alpha, cutmix enabled if > 0. (default: 1.0)')
parser.add_argument('--cutmix-minmax', type=float, nargs='+', default=None,
help='cutmix min/max ratio, overrides alpha and enables cutmix if set (default: None)')
parser.add_argument('--mixup-prob', type=float, default=1.0,
help='Probability of performing mixup or cutmix when either/both is enabled')
parser.add_argument('--mixup-switch-prob', type=float, default=0.5,
help='Probability of switching to cutmix when both mixup and cutmix enabled')
parser.add_argument('--mixup-mode', type=str, default='batch',
help='How to apply mixup/cutmix params. Per "batch", "pair", or "elem"')
# Distillation parameters
parser.add_argument('--teacher-model', default='regnety_160', type=str, metavar='MODEL',
help='Name of teacher model to train (default: "regnety_160")')
parser.add_argument('--teacher-path', type=str,
default='https://dl.fbaipublicfiles.com/deit/regnety_160-a5fe301d.pth')
parser.add_argument('--distillation-type', default='hard',
choices=['none', 'soft', 'hard'], type=str, help="")
parser.add_argument('--distillation-alpha',
default=0.5, type=float, help="")
parser.add_argument('--distillation-tau', default=1.0, type=float, help="")
# Finetuning params
parser.add_argument('--finetune', default='',
help='finetune from checkpoint')
parser.add_argument('--set_bn_eval', action='store_true', default=False,
help='set BN layers to eval mode during finetuning.')
# Dataset parameters
parser.add_argument('--data-path', default='/root/FastBaseline/data/imagenet', type=str,
help='dataset path')
parser.add_argument('--data-set', default='IMNET', choices=['CIFAR', 'IMNET', 'INAT', 'INAT19'],
type=str, help='ImageNet dataset path')
parser.add_argument('--inat-category', default='name',
choices=['kingdom', 'phylum', 'class', 'order',
'supercategory', 'family', 'genus', 'name'],
type=str, help='semantic granularity')
parser.add_argument('--output_dir', default='checkpoints',
help='path where to save, empty for no saving')
parser.add_argument('--device', default='cuda',
help='device to use for training / testing')
parser.add_argument('--seed', default=0, type=int)
parser.add_argument('--resume', default='', help='resume from checkpoint')
parser.add_argument('--start_epoch', default=0, type=int, metavar='N',
help='start epoch')
parser.add_argument('--eval', action='store_true',
help='Perform evaluation only')
parser.add_argument('--dist-eval', action='store_true',
default=False, help='Enabling distributed evaluation')
parser.add_argument('--num_workers', default=10, type=int)
parser.add_argument('--pin-mem', action='store_true',
help='Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.')
parser.add_argument('--no-pin-mem', action='store_false', dest='pin_mem',
help='')
parser.set_defaults(pin_mem=True)
# training parameters
parser.add_argument('--world_size', default=1, type=int,
help='number of distributed processes')
parser.add_argument('--dist_url', default='env://',
help='url used to set up distributed training')
parser.add_argument('--save_freq', default=1, type=int,
help='frequency of model saving')
parser.add_argument('--deploy', action='store_true', default=False)
parser.add_argument('--project', default='repvit', type=str)
return parser
import wandb
def main(args):
utils.init_distributed_mode(args)
if utils.is_main_process() and not args.eval:
wandb.init(project=args.project, config=args)
wandb.run.log_code('model')
if args.distillation_type != 'none' and args.finetune and not args.eval:
raise NotImplementedError(
"Finetuning with distillation not yet supported")
device = torch.device(args.device)
# fix the seed for reproducibility
seed = args.seed + utils.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)
# random.seed(seed)
cudnn.benchmark = True
dataset_train, args.nb_classes = build_dataset(is_train=True, args=args)
dataset_val, _ = build_dataset(is_train=False, args=args)
if True: # args.distributed:
num_tasks = utils.get_world_size()
global_rank = utils.get_rank()
if args.repeated_aug:
sampler_train = RASampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
else:
sampler_train = torch.utils.data.DistributedSampler(
dataset_train, num_replicas=num_tasks, rank=global_rank, shuffle=True
)
if args.dist_eval:
if len(dataset_val) % num_tasks != 0:
print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. '
'This will slightly alter validation results as extra duplicate entries are added to achieve '
'equal num of samples per-process.')
sampler_val = torch.utils.data.DistributedSampler(
dataset_val, num_replicas=num_tasks, rank=global_rank, shuffle=False)
else:
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
else:
sampler_train = torch.utils.data.RandomSampler(dataset_train)
sampler_val = torch.utils.data.SequentialSampler(dataset_val)
data_loader_train = torch.utils.data.DataLoader(
dataset_train, sampler=sampler_train,
batch_size=args.batch_size,
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=True,
)
if args.ThreeAugment:
data_loader_train.dataset.transform = new_data_aug_generator(args)
data_loader_val = torch.utils.data.DataLoader(
dataset_val, sampler=sampler_val,
batch_size=int(1.5 * args.batch_size),
num_workers=args.num_workers,
pin_memory=args.pin_mem,
drop_last=False
)
mixup_fn = None
mixup_active = args.mixup > 0 or args.cutmix > 0. or args.cutmix_minmax is not None
if mixup_active:
mixup_fn = Mixup(
mixup_alpha=args.mixup, cutmix_alpha=args.cutmix, cutmix_minmax=args.cutmix_minmax,
prob=args.mixup_prob, switch_prob=args.mixup_switch_prob, mode=args.mixup_mode,
label_smoothing=args.smoothing, num_classes=args.nb_classes)
print(f"Creating model: {args.model}")
model = create_model(
args.model,
num_classes=args.nb_classes,
distillation=(args.distillation_type != 'none'),
pretrained=False,
)
export_onnx(model, args.output_dir)
if args.finetune:
if args.finetune.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.finetune, map_location='cpu', check_hash=True)
else:
print("Loading local checkpoint at {}".format(args.finetune))
checkpoint = torch.load(args.finetune, map_location='cpu')
checkpoint_model = checkpoint['model']
state_dict = model.state_dict()
for k in ['head.l.weight', 'head.l.bias',
'head_dist.l.weight', 'head_dist.l.bias']:
if k in checkpoint_model and checkpoint_model[k].shape != state_dict[k].shape:
print(f"Removing key {k} from pretrained checkpoint")
del checkpoint_model[k]
msg = model.load_state_dict(checkpoint_model, strict=False)
print(msg)
model.to(device)
model_ema = None
if args.model_ema:
# Important to create EMA model after cuda(), DP wrapper, and AMP but
# before SyncBN and DDP wrapper
model_ema = ModelEma(
model,
decay=args.model_ema_decay,
device='cpu' if args.model_ema_force_cpu else '',
resume='')
model_without_ddp = model
if args.distributed:
model = torch.nn.parallel.DistributedDataParallel(
model, device_ids=[args.gpu])
model_without_ddp = model.module
n_parameters = sum(p.numel()
for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
linear_scaled_lr = args.lr * args.batch_size * utils.get_world_size() / 512.0
args.lr = linear_scaled_lr
optimizer = create_optimizer(args, model_without_ddp)
loss_scaler = NativeScaler()
lr_scheduler, _ = create_scheduler(args, optimizer)
criterion = LabelSmoothingCrossEntropy()
if args.mixup > 0.:
# smoothing is handled with mixup label transform
criterion = SoftTargetCrossEntropy()
elif args.smoothing:
criterion = LabelSmoothingCrossEntropy(smoothing=args.smoothing)
else:
criterion = torch.nn.CrossEntropyLoss()
teacher_model = None
if args.distillation_type != 'none':
assert args.teacher_path, 'need to specify teacher-path when using distillation'
print(f"Creating teacher model: {args.teacher_model}")
teacher_model = create_model(
args.teacher_model,
pretrained=False,
num_classes=args.nb_classes,
global_pool='avg',
)
if args.teacher_path.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.teacher_path, map_location='cpu', check_hash=True)
else:
checkpoint = torch.load(args.teacher_path, map_location='cpu')
teacher_model.load_state_dict(checkpoint['model'])
teacher_model.to(device)
teacher_model.eval()
# wrap the criterion in our custom DistillationLoss, which
# just dispatches to the original criterion if args.distillation_type is
# 'none'
criterion = DistillationLoss(
criterion, teacher_model, args.distillation_type, args.distillation_alpha, args.distillation_tau
)
output_dir = Path(args.output_dir)
if args.output_dir and utils.is_main_process():
with (output_dir / "model.txt").open("a") as f:
f.write(str(model))
print(str(model))
if args.output_dir and utils.is_main_process():
with (output_dir / "args.txt").open("a") as f:
f.write(json.dumps(args.__dict__, indent=2) + "\n")
print(json.dumps(args.__dict__, indent=2) + "\n")
if args.resume:
if args.resume.startswith('https'):
checkpoint = torch.hub.load_state_dict_from_url(
args.resume, map_location='cpu', check_hash=True)
else:
print("Loading local checkpoint at {}".format(args.resume))
checkpoint = torch.load(args.resume, map_location='cpu')
msg = model_without_ddp.load_state_dict(checkpoint['model'], strict=True)
print(msg)
if not args.eval and 'optimizer' in checkpoint and 'lr_scheduler' in checkpoint and 'epoch' in checkpoint:
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch'] + 1
if args.model_ema:
utils._load_checkpoint_for_ema(
model_ema, checkpoint['model_ema'])
if 'scaler' in checkpoint:
loss_scaler.load_state_dict(checkpoint['scaler'])
if args.eval:
utils.replace_batchnorm(model) # Users may choose whether to merge Conv-BN layers during eval
print(f"Evaluating model: {args.model}")
test_stats = evaluate(data_loader_val, model, device)
print(
f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
return
print(f"Start training for {args.epochs} epochs")
start_time = time.time()
max_accuracy = 0.0
max_accuracy_ema = 0.0
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
data_loader_train.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
model, criterion, data_loader_train,
optimizer, device, epoch, loss_scaler,
args.clip_grad, args.clip_mode, model_ema, mixup_fn,
# set_training_mode=args.finetune == '' # keep in eval mode during finetuning
set_training_mode=True,
set_bn_eval=args.set_bn_eval, # set bn to eval if finetune
)
lr_scheduler.step(epoch)
test_stats = evaluate(data_loader_val, model, device)
print(
f"Accuracy of the network on the {len(dataset_val)} test images: {test_stats['acc1']:.1f}%")
if args.output_dir:
ckpt_path = os.path.join(output_dir, 'checkpoint_'+str(epoch)+'.pth')
checkpoint_paths = [ckpt_path]
print("Saving checkpoint to {}".format(ckpt_path))
for checkpoint_path in checkpoint_paths:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, checkpoint_path)
remove_epoch = epoch - 3
if remove_epoch >= 0 and utils.is_main_process():
os.remove(os.path.join(output_dir, 'checkpoint_'+str(remove_epoch)+'.pth'))
if max_accuracy < test_stats["acc1"]:
utils.save_on_master({
'model': model_without_ddp.state_dict(),
'optimizer': optimizer.state_dict(),
'lr_scheduler': lr_scheduler.state_dict(),
'epoch': epoch,
'model_ema': get_state_dict(model_ema),
'scaler': loss_scaler.state_dict(),
'args': args,
}, os.path.join(output_dir, 'checkpoint_best.pth'))
max_accuracy = max(max_accuracy, test_stats["acc1"])
print(f'Max accuracy: {max_accuracy:.2f}%')
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if utils.is_main_process():
wandb.log({**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'max_accuracy': max_accuracy}, step=epoch)
if args.output_dir and utils.is_main_process():
with (output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
if utils.is_main_process():
wandb.finish()
def export_onnx(model, output_dir):
# if utils.is_main_process():
# dummy_input = torch.randn(1, 3, 224, 224)
# torch.onnx.export(model, dummy_input, f"{output_dir}/model.onnx")
# wandb.save(f"{output_dir}/model.onnx")
pass
if __name__ == '__main__':
parser = argparse.ArgumentParser(
'RepViT training and evaluation script', parents=[get_args_parser()])
args = parser.parse_args()
if args.resume and not args.eval:
args.output_dir = '/'.join(args.resume.split('/')[:-1])
elif args.output_dir:
args.output_dir = args.output_dir + f"/{args.model}/" + datetime.datetime.now().strftime('%Y_%m_%d_%H_%M_%S')
Path(args.output_dir).mkdir(parents=True, exist_ok=True)
else:
assert(False)
main(args)
# Model code
modelCode=698
# Model name
modelName=repvit_pytorch
# Model description
modelDescription=RepViT achieves over 80% top-1 accuracy with 1 ms latency on an iPhone 12 and serves as the backbone for several current SOTA instance segmentation methods.
# Application scenarios
appScenario=Training,Manufacturing,E-commerce,Healthcare,Energy,Education
# Framework type
frameType=pytorch
import model.repvit
import torch.nn as nn
def _make_divisible(v, divisor, min_value=None):
"""
This function is taken from the original tf repo.
It ensures that all layers have a channel number that is divisible by 8
It can be seen here:
https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
:param v:
:param divisor:
:param min_value:
:return:
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
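# Examples (divisor=8): _make_divisible(30, 8) == 32 and _make_divisible(24, 8) == 24,
# while _make_divisible(10, 8) == 16 because rounding down to 8 would fall below 90% of 10.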
from timm.models.layers import SqueezeExcite
import torch
class Conv2d_BN(torch.nn.Sequential):
def __init__(self, a, b, ks=1, stride=1, pad=0, dilation=1,
groups=1, bn_weight_init=1, resolution=-10000):
super().__init__()
self.add_module('c', torch.nn.Conv2d(
a, b, ks, stride, pad, dilation, groups, bias=False))
self.add_module('bn', torch.nn.BatchNorm2d(b))
torch.nn.init.constant_(self.bn.weight, bn_weight_init)
torch.nn.init.constant_(self.bn.bias, 0)
@torch.no_grad()
def fuse(self):
c, bn = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = c.weight * w[:, None, None, None]
b = bn.bias - bn.running_mean * bn.weight / \
(bn.running_var + bn.eps)**0.5
m = torch.nn.Conv2d(w.size(1) * self.c.groups, w.size(
0), w.shape[2:], stride=self.c.stride, padding=self.c.padding, dilation=self.c.dilation, groups=self.c.groups,
device=c.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
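# The fold above uses w' = w * gamma / sqrt(var + eps) and b' = beta - mean * gamma / sqrt(var + eps),
# so in eval mode the returned conv reproduces the Conv+BN pair. Illustrative check:
#   cb = Conv2d_BN(8, 16, ks=3, pad=1).eval()
#   x = torch.randn(1, 8, 32, 32)
#   assert torch.allclose(cb(x), cb.fuse()(x), atol=1e-5)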
class Residual(torch.nn.Module):
def __init__(self, m, drop=0.):
super().__init__()
self.m = m
self.drop = drop
def forward(self, x):
if self.training and self.drop > 0:
return x + self.m(x) * torch.rand(x.size(0), 1, 1, 1,
device=x.device).ge_(self.drop).div(1 - self.drop).detach()
else:
return x + self.m(x)
@torch.no_grad()
def fuse(self):
if isinstance(self.m, Conv2d_BN):
m = self.m.fuse()
assert(m.groups == m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1,1,1,1])
m.weight += identity.to(m.weight.device)
return m
elif isinstance(self.m, torch.nn.Conv2d):
m = self.m
assert(m.groups != m.in_channels)
identity = torch.ones(m.weight.shape[0], m.weight.shape[1], 1, 1)
identity = torch.nn.functional.pad(identity, [1,1,1,1])
m.weight += identity.to(m.weight.device)
return m
else:
return self
class RepVGGDW(torch.nn.Module):
def __init__(self, ed) -> None:
super().__init__()
self.conv = Conv2d_BN(ed, ed, 3, 1, 1, groups=ed)
self.conv1 = torch.nn.Conv2d(ed, ed, 1, 1, 0, groups=ed)
self.dim = ed
self.bn = torch.nn.BatchNorm2d(ed)
def forward(self, x):
return self.bn((self.conv(x) + self.conv1(x)) + x)
@torch.no_grad()
def fuse(self):
conv = self.conv.fuse()
conv1 = self.conv1
conv_w = conv.weight
conv_b = conv.bias
conv1_w = conv1.weight
conv1_b = conv1.bias
conv1_w = torch.nn.functional.pad(conv1_w, [1,1,1,1])
identity = torch.nn.functional.pad(torch.ones(conv1_w.shape[0], conv1_w.shape[1], 1, 1, device=conv1_w.device), [1,1,1,1])
final_conv_w = conv_w + conv1_w + identity
final_conv_b = conv_b + conv1_b
conv.weight.data.copy_(final_conv_w)
conv.bias.data.copy_(final_conv_b)
bn = self.bn
w = bn.weight / (bn.running_var + bn.eps)**0.5
w = conv.weight * w[:, None, None, None]
b = bn.bias + (conv.bias - bn.running_mean) * bn.weight / \
(bn.running_var + bn.eps)**0.5
conv.weight.data.copy_(w)
conv.bias.data.copy_(b)
return conv
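# Reparameterization check (illustrative): in eval mode the fused 3x3 conv should match
# the three-branch block, since the 1x1 branch and the identity are padded to 3x3 and
# merged together with the BatchNorm statistics:
#   blk = RepVGGDW(32).eval()
#   x = torch.randn(1, 32, 56, 56)
#   ref = blk(x)
#   assert torch.allclose(ref, blk.fuse()(x), atol=1e-5)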
class RepViTBlock(nn.Module):
def __init__(self, inp, hidden_dim, oup, kernel_size, stride, use_se, use_hs):
super(RepViTBlock, self).__init__()
assert stride in [1, 2]
self.identity = stride == 1 and inp == oup
assert(hidden_dim == 2 * inp)
if stride == 2:
self.token_mixer = nn.Sequential(
Conv2d_BN(inp, inp, kernel_size, stride, (kernel_size - 1) // 2, groups=inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
Conv2d_BN(inp, oup, ks=1, stride=1, pad=0)
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(oup, 2 * oup, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(2 * oup, oup, 1, 1, 0, bn_weight_init=0),
))
else:
assert(self.identity)
self.token_mixer = nn.Sequential(
RepVGGDW(inp),
SqueezeExcite(inp, 0.25) if use_se else nn.Identity(),
)
self.channel_mixer = Residual(nn.Sequential(
# pw
Conv2d_BN(inp, hidden_dim, 1, 1, 0),
nn.GELU() if use_hs else nn.GELU(),
# pw-linear
Conv2d_BN(hidden_dim, oup, 1, 1, 0, bn_weight_init=0),
))
def forward(self, x):
return self.channel_mixer(self.token_mixer(x))
from timm.models.vision_transformer import trunc_normal_
class BN_Linear(torch.nn.Sequential):
def __init__(self, a, b, bias=True, std=0.02):
super().__init__()
self.add_module('bn', torch.nn.BatchNorm1d(a))
self.add_module('l', torch.nn.Linear(a, b, bias=bias))
trunc_normal_(self.l.weight, std=std)
if bias:
torch.nn.init.constant_(self.l.bias, 0)
@torch.no_grad()
def fuse(self):
bn, l = self._modules.values()
w = bn.weight / (bn.running_var + bn.eps)**0.5
b = bn.bias - self.bn.running_mean * \
self.bn.weight / (bn.running_var + bn.eps)**0.5
w = l.weight * w[None, :]
if l.bias is None:
b = b @ self.l.weight.T
else:
b = (l.weight @ b[:, None]).view(-1) + self.l.bias
m = torch.nn.Linear(w.size(1), w.size(0), device=l.weight.device)
m.weight.data.copy_(w)
m.bias.data.copy_(b)
return m
class Classfier(nn.Module):
def __init__(self, dim, num_classes, distillation=True):
super().__init__()
self.classifier = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
self.distillation = distillation
if distillation:
self.classifier_dist = BN_Linear(dim, num_classes) if num_classes > 0 else torch.nn.Identity()
def forward(self, x):
if self.distillation:
x = self.classifier(x), self.classifier_dist(x)
if not self.training:
x = (x[0] + x[1]) / 2
else:
x = self.classifier(x)
return x
@torch.no_grad()
def fuse(self):
classifier = self.classifier.fuse()
if self.distillation:
classifier_dist = self.classifier_dist.fuse()
classifier.weight += classifier_dist.weight
classifier.bias += classifier_dist.bias
classifier.weight /= 2
classifier.bias /= 2
return classifier
else:
return classifier
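# Note: fuse() folds each BN_Linear head into a single nn.Linear and averages the two
# heads' weights and biases, mirroring the (x[0] + x[1]) / 2 averaging used at inference
# time in forward().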
class RepViT(nn.Module):
def __init__(self, cfgs, num_classes=1000, distillation=False):
super(RepViT, self).__init__()
# setting of inverted residual blocks
self.cfgs = cfgs
# building first layer
input_channel = self.cfgs[0][2]
patch_embed = torch.nn.Sequential(Conv2d_BN(3, input_channel // 2, 3, 2, 1), torch.nn.GELU(),
Conv2d_BN(input_channel // 2, input_channel, 3, 2, 1))
layers = [patch_embed]
# building inverted residual blocks
block = RepViTBlock
for k, t, c, use_se, use_hs, s in self.cfgs:
output_channel = _make_divisible(c, 8)
exp_size = _make_divisible(input_channel * t, 8)
layers.append(block(input_channel, exp_size, output_channel, k, s, use_se, use_hs))
input_channel = output_channel
self.features = nn.ModuleList(layers)
self.classifier = Classfier(output_channel, num_classes, distillation)
def forward(self, x):
# x = self.features(x)
for f in self.features:
x = f(x)
x = torch.nn.functional.adaptive_avg_pool2d(x, 1).flatten(1)
x = self.classifier(x)
return x
from timm.models import register_model
@register_model
def repvit_m0_6(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M0.6 model
"""
cfgs = [
[3, 2, 40, 1, 0, 1],
[3, 2, 40, 0, 0, 1],
[3, 2, 80, 0, 0, 2],
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 160, 0, 1, 2],
[3, 2, 160, 1, 1, 1],
[3, 2, 160, 0, 1, 1],
[3, 2, 160, 1, 1, 1],
[3, 2, 160, 0, 1, 1],
[3, 2, 160, 1, 1, 1],
[3, 2, 160, 0, 1, 1],
[3, 2, 160, 1, 1, 1],
[3, 2, 160, 0, 1, 1],
[3, 2, 160, 0, 1, 1],
[3, 2, 320, 0, 1, 2],
[3, 2, 320, 1, 1, 1],
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
@register_model
def repvit_m0_9(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M0.9 model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 48, 1, 0, 1],
[3, 2, 48, 0, 0, 1],
[3, 2, 48, 0, 0, 1],
[3, 2, 96, 0, 0, 2],
[3, 2, 96, 1, 0, 1],
[3, 2, 96, 0, 0, 1],
[3, 2, 96, 0, 0, 1],
[3, 2, 192, 0, 1, 2],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 1, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 192, 0, 1, 1],
[3, 2, 384, 0, 1, 2],
[3, 2, 384, 1, 1, 1],
[3, 2, 384, 0, 1, 1]
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
@register_model
def repvit_m1_0(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M1.0 model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 56, 1, 0, 1],
[3, 2, 56, 0, 0, 1],
[3, 2, 56, 0, 0, 1],
[3, 2, 112, 0, 0, 2],
[3, 2, 112, 1, 0, 1],
[3, 2, 112, 0, 0, 1],
[3, 2, 112, 0, 0, 1],
[3, 2, 224, 0, 1, 2],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 1, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 224, 0, 1, 1],
[3, 2, 448, 0, 1, 2],
[3, 2, 448, 1, 1, 1],
[3, 2, 448, 0, 1, 1]
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
@register_model
def repvit_m1_1(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M1.1 model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 128, 0, 0, 2],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 256, 0, 1, 2],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 512, 0, 1, 2],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1]
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
@register_model
def repvit_m1_5(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M1.5 model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 1, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 64, 0, 0, 1],
[3, 2, 128, 0, 0, 2],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 1, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 128, 0, 0, 1],
[3, 2, 256, 0, 1, 2],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 1, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 256, 0, 1, 1],
[3, 2, 512, 0, 1, 2],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1],
[3, 2, 512, 1, 1, 1],
[3, 2, 512, 0, 1, 1]
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
@register_model
def repvit_m2_3(pretrained=False, num_classes = 1000, distillation=False):
"""
Constructs a RepViT-M2.3 model
"""
cfgs = [
# k, t, c, SE, HS, s
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 1, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 80, 0, 0, 1],
[3, 2, 160, 0, 0, 2],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 1, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 160, 0, 0, 1],
[3, 2, 320, 0, 1, 2],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 320, 1, 1, 1],
[3, 2, 320, 0, 1, 1],
# [3, 2, 320, 1, 1, 1],
# [3, 2, 320, 0, 1, 1],
[3, 2, 320, 0, 1, 1],
[3, 2, 640, 0, 1, 2],
[3, 2, 640, 1, 1, 1],
[3, 2, 640, 0, 1, 1],
# [3, 2, 640, 1, 1, 1],
# [3, 2, 640, 0, 1, 1]
]
return RepViT(cfgs, num_classes=num_classes, distillation=distillation)
torch
timm==0.5.4
fvcore
*.pyc
*.pyo
*.pyd
__py
**/__pycache__/
repvit_sam.egg-info
weights/*.pt
*.pt
*.onnx
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
# Contributing to segment-anything
We want to make contributing to this project as easy and transparent as
possible.
## Pull Requests
We actively welcome your pull requests.
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints, using the `linter.sh` script in the project's root directory. Linting requires `black==23.*`, `isort==5.12.0`, `flake8`, and `mypy`.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.
Complete your CLA here: <https://code.facebook.com/cla>
## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.
Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.
## License
By contributing to segment-anything, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# [RepViT-SAM: Towards Real-Time Segmenting Anything](https://arxiv.org/abs/2312.05760)
Official PyTorch implementation of **RepViT-SAM**, from the following paper:
[RepViT-SAM: Towards Real-Time Segmenting Anything](https://arxiv.org/abs/2312.05760).\
Ao Wang, Hui Chen, Zijia Lin, Jungong Han, and Guiguang Ding\
[[`arXiv`](https://arxiv.org/abs/2312.05760)]
<p align="center">
<img src="figures/comparison.png" width=70%> <br>
Models are deployed on iPhone 12 with Core ML Tools to get latency.
</p>
<details>
<summary>
<font size="+1">Abstract</font>
</summary>
Segment Anything Model (SAM) has shown impressive zero-shot transfer performance for various computer vision tasks recently. However, its heavy computation costs remain daunting for practical applications. MobileSAM proposes to replace the heavyweight image encoder in SAM with TinyViT by employing distillation, which results in a significant reduction in computational requirements. However, its deployment on resource-constrained mobile devices still encounters challenges due to the substantial memory and computational overhead caused by self-attention mechanisms. Recently, RepViT achieves the state-of-the-art performance and latency trade-off on mobile devices by incorporating efficient architectural designs of ViTs into CNNs. Here, to achieve real-time segmenting anything on mobile devices, following MobileSAM, we replace the heavyweight image encoder in SAM with the RepViT model, ending up with the RepViT-SAM model. Extensive experiments show that RepViT-SAM can enjoy significantly better zero-shot transfer capability than MobileSAM, along with nearly $10\times$ faster inference speed.
</details>
<br/>
## Installation
```bash
pip install -e .
# download pretrained checkpoint
mkdir weights && cd weights
wget https://github.com/THU-MIG/RepViT/releases/download/v1.0/repvit_sam.pt
```
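After downloading the checkpoint, the model can be used like the original Segment Anything API. The snippet below is a minimal sketch under that assumption; the registry key (assumed here to be `repvit`), the prompt point, and `example.jpg` are illustrative placeholders.
```python
import numpy as np
import cv2
from repvit_sam import sam_model_registry, SamPredictor

# Assumes repvit_sam mirrors the segment-anything API and registers the
# RepViT image encoder under the "repvit" key.
sam = sam_model_registry["repvit"](checkpoint="weights/repvit_sam.pt")
sam.eval()
predictor = SamPredictor(sam)

image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)
predictor.set_image(image)
masks, scores, logits = predictor.predict(
    point_coords=np.array([[500, 375]]),  # illustrative foreground point
    point_labels=np.array([1]),
    multimask_output=True,
)
```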
## Demo
Our Hugging Face demo is [here](https://huggingface.co/spaces/jameslahm/repvit-sam)
```
python app/app.py
```
## CoreML export
Please refer to [coreml_example.ipynb](./notebooks/coreml_example.ipynb)
## Latency comparisons
Comparison between RepViT-SAM and others in terms of latency. Latency (ms) is measured at the standard 1024 $\times$ 1024 resolution on an iPhone 12 and a MacBook M1 Pro using Core ML Tools. OOM means out of memory.
<table class="tg">
<thead>
<tr>
<th class="tg-c3ow" rowspan="2">Platform</th>
<th class="tg-c3ow" colspan="3">Image encoder</th>
<th class="tg-c3ow" rowspan="2">Mask decoder</th>
</tr>
<tr>
<th class="tg-c3ow" rowspan="1">RepViT-SAM</th>
<th class="tg-c3ow" rowspan="1">MobileSAM</th>
<th class="tg-c3ow" rowspan="1">ViT-B-SAM</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-c3ow">iPhone</td>
<td class="tg-c3ow"><b>48.9ms</b></td>
<td class="tg-c3ow">OOM</td>
<td class="tg-c3ow">OOM</td>
<td class="tg-c3ow">11.6ms</td>
</tr>
<tr>
<td class="tg-c3ow">Macbook</td>
<td class="tg-c3ow"><b>44.8ms</b></td>
<td class="tg-c3ow">482.2ms</td>
<td class="tg-c3ow">6249.5ms</td>
<td class="tg-c3ow">11.8ms</td>
</tr>
</tbody>
</table>
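A hedged sketch of how the MacBook-side timing could be reproduced with Core ML Tools is shown below; the package path and the input name `image` come from the export sketch above and are assumptions, not artifacts shipped with this repo.
```python
# Rough latency measurement of an exported Core ML encoder on macOS.
import time
import numpy as np
import coremltools as ct

model = ct.models.MLModel("repvit_sam_encoder.mlpackage")  # assumed path
x = np.random.rand(1, 3, 1024, 1024).astype(np.float32)

for _ in range(5):                # warm-up runs
    model.predict({"image": x})

runs = 20
start = time.perf_counter()
for _ in range(runs):
    model.predict({"image": x})
print(f"mean latency: {(time.perf_counter() - start) / runs * 1000:.1f} ms")
```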
## Zero-shot edge detection
Comparison results on BSDS500.
<table class="tg">
<thead>
<tr>
<th class="tg-c3ow" rowspan="2">Model</th>
<th class="tg-c3ow" colspan="3">zero-shot edge detection</th>
</tr>
<tr>
<th class="tg-c3ow">ODS</th>
<th class="tg-c3ow">OIS</th>
<th class="tg-c3ow">AP</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-c3ow">ViT-H-SAM</td>
<td class="tg-c3ow"><b>.768</b></td>
<td class="tg-c3ow"><b>.786</b></td>
<td class="tg-c3ow"><b>.794</b></td>
</tr>
<tr>
<td class="tg-c3ow">ViT-B-SAM</td>
<td class="tg-c3ow">.743</td>
<td class="tg-c3ow">.764</td>
<td class="tg-c3ow">.726</td>
</tr>
<tr>
<td class="tg-c3ow">MobileSAM</td>
<td class="tg-c3ow">.756</td>
<td class="tg-c3ow">.768</td>
<td class="tg-c3ow">.746</td>
</tr>
<tr>
<td class="tg-c3ow">RepViT-SAM</td>
<td class="tg-c3ow"><ins>.764</ins></td>
<td class="tg-c3ow"><ins>.786</ins></td>
<td class="tg-c3ow"><ins>.773</ins></td>
</tr>
</tbody>
</table>
## Zero-shot instance segmentation and SegInW
Comparison results on COCO and SegInW.
<table class="tg">
<thead>
<tr>
<th class="tg-c3ow" rowspan="2">Model</th>
<th class="tg-c3ow" colspan="4">zero-shot instance segmentation</th>
<th class="tg-c3ow">SegInW</th>
</tr>
<tr>
<th class="tg-c3ow">AP</th>
<th class="tg-c3ow">$AP^{S}$</th>
<th class="tg-c3ow">$AP^{M}$</th>
<th class="tg-c3ow">$AP^{L}$</th>
<th class="tg-c3ow">Mean AP</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-c3ow">ViT-H-SAM</td>
<td class="tg-c3ow"><b>46.8</b></td>
<td class="tg-c3ow"><b>31.8</b></td>
<td class="tg-c3ow"><b>51.0</b></td>
<td class="tg-c3ow"><b>63.6</b></td>
<td class="tg-c3ow"><b>48.7</b></td>
</tr>
<tr>
<td class="tg-c3ow">ViT-B-SAM</td>
<td class="tg-c3ow">42.5</td>
<td class="tg-c3ow"><ins>29.8</ins></td>
<td class="tg-c3ow">47.0</td>
<td class="tg-c3ow">56.8</td>
<td class="tg-c3ow">44.8</td>
</tr>
<tr>
<td class="tg-c3ow">MobileSAM</td>
<td class="tg-c3ow">42.7</td>
<td class="tg-c3ow">27.0</td>
<td class="tg-c3ow">46.5</td>
<td class="tg-c3ow">61.1</td>
<td class="tg-c3ow">43.9</td>
</tr>
<tr>
<td class="tg-c3ow">RepViT-SAM</td>
<td class="tg-c3ow"><ins>44.4</ins></td>
<td class="tg-c3ow">29.1</td>
<td class="tg-c3ow"><ins>48.6</ins></td>
<td class="tg-c3ow"><ins>61.4</ins></td>
<td class="tg-c3ow"><ins>46.1</ins></td>
</tr>
</tbody>
</table>
## Zero-shot video object/instance segmentation
Comparison results on DAVIS 2017 and UVO.
<table class="tg">
<thead>
<tr>
<th class="tg-c3ow" rowspan="2">Model</th>
<th class="tg-c3ow" colspan="3">z.s. VOS</th>
<th class="tg-c3ow">z.s. VIS</th>
</tr>
<tr>
<th class="tg-c3ow">$\mathcal{J\&amp;F}$</th>
<th class="tg-c3ow">$\mathcal{J}$</th>
<th class="tg-c3ow">$\mathcal{F}$</th>
<th class="tg-c3ow">AR100</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-c3ow">ViT-H-SAM</td>
<td class="tg-c3ow"><b>77.4</b></td>
<td class="tg-c3ow"><b>74.6</b></td>
<td class="tg-c3ow"><b>80.2</b></td>
<td class="tg-c3ow"><b>28.8</b></td>
</tr>
<tr>
<td class="tg-c3ow">ViT-B-SAM</td>
<td class="tg-c3ow">71.3</td>
<td class="tg-c3ow">68.5</td>
<td class="tg-c3ow">74.1</td>
<td class="tg-c3ow">19.1</td>
</tr>
<tr>
<td class="tg-c3ow">MobileSAM</td>
<td class="tg-c3ow">71.1</td>
<td class="tg-c3ow">68.6</td>
<td class="tg-c3ow">73.6</td>
<td class="tg-c3ow">22.7</td>
</tr>
<tr>
<td class="tg-c3ow">RepViT-SAM</td>
<td class="tg-c3ow"><ins>73.5</ins></td>
<td class="tg-c3ow"><ins>71.0</ins></td>
<td class="tg-c3ow"><ins>76.1</ins></td>
<td class="tg-c3ow"><ins>25.3</ins></td>
</tr>
</tbody>
</table>
## Zero-shot salient object segmentation
Comparison results on DUTS.
## Zero-shot anomaly detection
Comparison results on MVTec.
<table class="tg">
<thead>
<tr>
<th class="tg-c3ow" rowspan="2">Model</th>
<th class="tg-c3ow">z.s. s.o.s.</th>
<th class="tg-c3ow">z.s. a.d.</th>
</tr>
<tr>
<th class="tg-c3ow">$\mathcal{M}$ $\downarrow$</th>
<th class="tg-c3ow">$\mathcal{F}_{p}$</th>
</tr>
</thead>
<tbody>
<tr>
<td class="tg-c3ow">ViT-H-SAM</td>
<td class="tg-c3ow"><b>0.046</b></td>
<td class="tg-c3ow"><ins>37.65</ins></td>
</tr>
<tr>
<td class="tg-c3ow">ViT-B-SAM</td>
<td class="tg-c3ow">0.121</td>
<td class="tg-c3ow">36.62</td>
</tr>
<tr>
<td class="tg-c3ow">MobileSAM</td>
<td class="tg-c3ow">0.147</td>
<td class="tg-c3ow">36.44</td>
</tr>
<tr>
<td class="tg-c3ow">RepViT-SAM</td>
<td class="tg-c3ow"><ins>0.066</ins></td>
<td class="tg-c3ow"><b>37.96</b></td>
</tr>
</tbody>
</table>
## Acknowledgement
The code base is partly built with [SAM](https://github.com/facebookresearch/segment-anything) and [MobileSAM](https://github.com/ChaoningZhang/MobileSAM).
Thanks for the great implementations!
## Citation
If our code or models help your work, please cite our paper:
```BibTeX
@misc{wang2023repvitsam,
title={RepViT-SAM: Towards Real-Time Segmenting Anything},
author={Ao Wang and Hui Chen and Zijia Lin and Jungong Han and Guiguang Ding},
year={2023},
eprint={2312.05760},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
assets/sa_1309.jpg filter=lfs diff=lfs merge=lfs -text
assets/sa_192.jpg filter=lfs diff=lfs merge=lfs -text
assets/sa_414.jpg filter=lfs diff=lfs merge=lfs -text
assets/sa_862.jpg filter=lfs diff=lfs merge=lfs -text
import os
import gradio as gr
import numpy as np
import torch
from repvit_sam import SamAutomaticMaskGenerator, SamPredictor, sam_model_registry
from PIL import ImageDraw
from utils.tools import box_prompt, format_results, point_prompt
from utils.tools_gradio import fast_process
# Most of our demo code is adapted from the [FastSAM demo](https://huggingface.co/spaces/An-619/FastSAM). Huge thanks to An-619.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Load the pre-trained model
sam_checkpoint = "weights/repvit_sam.pt"
model_type = "repvit"
repvit_sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
repvit_sam = repvit_sam.to(device=device)
repvit_sam.eval()
mask_generator = SamAutomaticMaskGenerator(repvit_sam)
predictor = SamPredictor(repvit_sam)
# Description
title = "<center><strong><font size='8'>RepViT-SAM<font></strong></center>"
description_e = """This is a demo of [RepViT-SAM](https://github.com/THU-MIG/RepViT).
We will provide box mode soon.
Enjoy!
"""
description_p = """ Instructions for point mode
0. Restart by click the Restart button
1. Select a point with Add Mask for the foreground (Must)
2. Select a point with Remove Area for the background (Optional)
3. Click the Start Segmenting.
Github [link](https://github.com/THU-MIG/RepViT)
"""
examples = [
["app/assets/picture3.jpg"],
["app/assets/picture4.jpg"],
["app/assets/picture6.jpg"],
["app/assets/picture1.jpg"],
]
default_example = examples[0]
css = "h1 { text-align: center } .about { text-align: justify; padding-left: 10%; padding-right: 10%; }"
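# Segment the image with the accumulated point prompts and return the rendered
# overlay together with the resized original image.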
def segment_with_points(
image,
original_image,
input_size=1024,
better_quality=False,
withContours=True,
use_retina=True,
mask_random_color=True,
):
global global_points
global global_point_label
input_size = int(input_size)
w, h = image.size
scale = input_size / max(w, h)
new_w = int(w * scale)
new_h = int(h * scale)
image = image.resize((new_w, new_h))
scaled_points = np.array(
[[int(x * scale) for x in point] for point in global_points]
)
scaled_point_label = np.array(global_point_label)
if scaled_points.size == 0 and scaled_point_label.size == 0:
print("No points selected")
return image, image
nd_image = np.array(original_image.resize((new_w, new_h)))
predictor.set_image(nd_image)
masks, scores, logits = predictor.predict(
point_coords=scaled_points,
point_labels=scaled_point_label,
multimask_output=False,
)
results = format_results(masks, scores, logits, 0)
annotations, _ = point_prompt(
results, scaled_points, scaled_point_label, new_h, new_w
)
annotations = np.array([annotations])
fig = fast_process(
annotations=annotations,
image=image,
device=device,
scale=(1024 // input_size),
better_quality=better_quality,
mask_random_color=mask_random_color,
bbox=None,
use_retina=use_retina,
withContours=withContours,
)
global_points = []
global_point_label = []
# return fig, None
return fig, original_image.resize((new_w, new_h))
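# Record a clicked point (foreground or background, depending on the radio label)
# and draw it on the preview image.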
def get_points_with_draw(image, label, evt: gr.SelectData):
global global_points
global global_point_label
x, y = evt.index[0], evt.index[1]
point_radius = 15 * (max(image.width, image.height) / 1024)
point_color = (255, 255, 0) if label == "Add Mask" else (255, 0, 255)
global_points.append([x, y])
global_point_label.append(1 if label == "Add Mask" else 0)
# Create a drawing object to mark the clicked point on the image
draw = ImageDraw.Draw(image)
draw.ellipse(
[(x - point_radius, y - point_radius), (x + point_radius, y + point_radius)],
fill=point_color,
)
return image
cond_img_e = gr.Image(label="Input", value=default_example[0], type="pil")
cond_img_p = gr.Image(label="Input with points", value=default_example[0], type="pil")
segm_img_e = gr.Image(label="Segmented Image", interactive=False, type="pil")
segm_img_p = gr.Image(
label="Segmented Image with points", interactive=True, type="pil"
)
global_points = []
global_point_label = []
input_size_slider = gr.components.Slider(
minimum=512,
maximum=1024,
value=1024,
step=64,
label="Input_size",
info="Our model was trained on a size of 1024",
)
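# Build the Gradio UI: a point-mode tab with paired input/output images,
# an example gallery, and segment/restart controls.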
with gr.Blocks(css=css, title="RepViT-SAM") as demo:
from PIL import Image
original_image = gr.State(value=Image.open(default_example[0]).convert('RGB'))
with gr.Row():
with gr.Column(scale=1):
# Title
gr.Markdown(title)
with gr.Tab("Point mode"):
# Images
with gr.Row(variant="panel"):
with gr.Column(scale=1):
cond_img_p.render()
with gr.Column(scale=1):
segm_img_p.render()
# Submit & Clear
with gr.Row():
with gr.Column():
with gr.Row():
add_or_remove = gr.Radio(
["Add Mask", "Remove Area"],
value="Add Mask",
)
with gr.Column():
segment_btn_p = gr.Button(
"Start segmenting!", variant="primary"
)
clear_btn_p = gr.Button("Restart", variant="secondary")
gr.Markdown("Try some of the examples below ⬇️")
gr.Examples(
examples=examples,
inputs=[cond_img_p],
fn=lambda x: x,
outputs=[original_image],
# fn=segment_with_points,
# cache_examples=True,
examples_per_page=4,
run_on_click=True
)
with gr.Column():
# Description
gr.Markdown(description_p)
cond_img_p.select(get_points_with_draw, [cond_img_p, add_or_remove], cond_img_p)
cond_img_p.upload(lambda x: x, inputs=[cond_img_p], outputs=[original_image])
# segment_btn_e.click(
# segment_everything,
# inputs=[
# cond_img_e,
# input_size_slider,
# mor_check,
# contour_check,
# retina_check,
# ],
# outputs=segm_img_e,
# )
segment_btn_p.click(
segment_with_points, inputs=[cond_img_p, original_image], outputs=[segm_img_p, cond_img_p]
)
def clear():
return None, None
def clear_text():
return None, None, None
# clear_btn_e.click(clear, outputs=[cond_img_e, segm_img_e])
clear_btn_p.click(clear, outputs=[cond_img_p, segm_img_p])
demo.queue()
demo.launch()