Commit 46260e34 authored by suily

Initial commit
import torch
import torch.nn as nn
import torch.cuda.amp as amp
from src.core import register
import src.misc.dist as dist
__all__ = ['GradScaler']
GradScaler = register(amp.grad_scaler.GradScaler)
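# A minimal, hedged sketch of how the registered GradScaler is typically driven
# in one AMP training step (standard PyTorch usage; the variable names below are
# illustrative and not part of this repo's config system, and the criterion is
# assumed to return a dict of losses as in det_engine.py further down).
def _amp_step_example(model, criterion, optimizer, samples, targets, device, scaler):
    with torch.autocast(device_type=device.type):
        outputs = model(samples)
        loss = sum(criterion(outputs, targets).values())
    scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
    scaler.step(optimizer)          # unscales gradients and skips the step on inf/nan
    scaler.update()                 # adapt the scale factor for the next iteration
    optimizer.zero_grad()
    return loss.detach()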
"""
reference:
https://github.com/ultralytics/yolov5/blob/master/utils/torch_utils.py#L404
by lyuwenyu
"""
import torch
import torch.nn as nn
import math
from copy import deepcopy
from src.core import register
import src.misc.dist as dist
__all__ = ['ModelEMA']
@register
class ModelEMA(object):
""" Model Exponential Moving Average from https://github.com/rwightman/pytorch-image-models
Keep a moving average of everything in the model state_dict (parameters and buffers).
This is intended to allow functionality like
https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
A smoothed version of the weights is necessary for some training schemes to perform well.
This class is sensitive to where it is initialized in the sequence of model init,
GPU assignment, and distributed training wrappers.
"""
def __init__(self, model: nn.Module, decay: float=0.9999, warmups: int=2000):
super().__init__()
# Create EMA
self.module = deepcopy(dist.de_parallel(model)).eval() # FP32 EMA
# if next(model.parameters()).device.type != 'cpu':
# self.module.half() # FP16 EMA
self.decay = decay
self.warmups = warmups
self.updates = 0 # number of EMA updates
# self.filter_no_grad = filter_no_grad
self.decay_fn = lambda x: decay * (1 - math.exp(-x / warmups)) # decay exponential ramp (to help early epochs)
for p in self.module.parameters():
p.requires_grad_(False)
def update(self, model: nn.Module):
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay_fn(self.updates)
msd = dist.de_parallel(model).state_dict()
for k, v in self.module.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
def to(self, *args, **kwargs):
self.module = self.module.to(*args, **kwargs)
return self
def update_attr(self, model, include=(), exclude=('process_group', 'reducer')):
# Update EMA attributes
self.copy_attr(self.module, model, include, exclude)
@staticmethod
def copy_attr(a, b, include=(), exclude=()):
# Copy attributes from b to a, options to only include [...] and to exclude [...]
for k, v in b.__dict__.items():
if (len(include) and k not in include) or k.startswith('_') or k in exclude:
continue
else:
setattr(a, k, v)
def state_dict(self, ):
return dict(module=self.module.state_dict(), updates=self.updates, warmups=self.warmups)
def load_state_dict(self, state):
self.module.load_state_dict(state['module'])
if 'updates' in state:
self.updates = state['updates']
def forward(self, ):
raise RuntimeError('ModelEMA is not callable; evaluate with `self.module` instead')
def extra_repr(self) -> str:
return f'decay={self.decay}, warmups={self.warmups}'
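# Hedged usage sketch (illustrative names, not part of the original file): how
# ModelEMA is driven from a training loop. In this repo the wiring is done by
# the solver via `cfg.ema`; see det_engine.py / det_solver.py below.
def _model_ema_example(model, dataloader, criterion, optimizer):
    ema = ModelEMA(model, decay=0.9999, warmups=2000)
    for samples, targets in dataloader:
        loss = sum(criterion(model(samples), targets).values())  # assumes a dict of losses
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema.update(model)   # EMA weights follow the live weights with the warmup-ramped decay
    return ema.module       # evaluate this smoothed copy, not the raw `model`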
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
"""Maintains moving averages of model parameters using an exponential decay.
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
is used to compute the EMA.
"""
def __init__(self, model, decay, device="cpu", use_buffers=True):
self.decay_fn = lambda x: decay * (1 - math.exp(-x / 2000))
def ema_avg(avg_model_param, model_param, num_averaged):
decay = self.decay_fn(num_averaged)
return decay * avg_model_param + (1 - decay) * model_param
super().__init__(model, device, ema_avg, use_buffers=use_buffers)
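# Hedged usage sketch for the AveragedModel-based variant above (illustrative
# names): `update_parameters` and `.module` are the stock
# torch.optim.swa_utils.AveragedModel API.
def _averaged_ema_example(model, dataloader, criterion, optimizer):
    ema_model = ExponentialMovingAverage(model, decay=0.9999)
    for samples, targets in dataloader:
        loss = sum(criterion(model(samples), targets).values())  # assumes a dict of losses
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        ema_model.update_parameters(model)  # applies ema_avg with the ramped decay
    return ema_model.module                 # the averaged copy of the weights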
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from src.core import register
__all__ = ['AdamW', 'SGD', 'Adam', 'MultiStepLR', 'CosineAnnealingLR', 'OneCycleLR', 'LambdaLR']
SGD = register(optim.SGD)
Adam = register(optim.Adam)
AdamW = register(optim.AdamW)
MultiStepLR = register(lr_scheduler.MultiStepLR)
CosineAnnealingLR = register(lr_scheduler.CosineAnnealingLR)
OneCycleLR = register(lr_scheduler.OneCycleLR)
LambdaLR = register(lr_scheduler.LambdaLR)
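# Hedged sketch (illustrative, outside the YAML/config system): the registered
# names are the stock torch.optim / lr_scheduler classes, assuming `register`
# returns the wrapped class unchanged as the assignments above suggest, so they
# compose exactly like plain PyTorch.
def _optim_example(model):
    optimizer = AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)
    scheduler = MultiStepLR(optimizer, milestones=[100, 200], gamma=0.1)
    return optimizer, scheduler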
"""by lyuwenyu
"""
from .solver import BaseSolver
from .det_solver import DetSolver
from typing import Dict
TASKS: Dict[str, BaseSolver] = {
'detection': DetSolver,
}
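# Hedged sketch of how a training entry point might consume TASKS (the `cfg`
# object and the 'detection' default are assumptions for illustration; the
# solvers themselves take a BaseConfig in __init__):
def _build_solver(cfg, task: str = 'detection'):
    solver = TASKS[task](cfg)   # e.g. 'detection' -> DetSolver
    return solver               # caller then runs solver.fit() or solver.val()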
"""
Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
https://github.com/facebookresearch/detr/blob/main/engine.py
by lyuwenyu
"""
import math
import os
import sys
import pathlib
from typing import Iterable
import torch
import torch.amp
from src.data import CocoEvaluator
from src.misc import (MetricLogger, SmoothedValue, reduce_dict)
def train_one_epoch(model: torch.nn.Module, criterion: torch.nn.Module,
data_loader: Iterable, optimizer: torch.optim.Optimizer,
device: torch.device, epoch: int, max_norm: float = 0, **kwargs):
model.train()
criterion.train()
metric_logger = MetricLogger(delimiter=" ")
metric_logger.add_meter('lr', SmoothedValue(window_size=1, fmt='{value:.6f}'))
# metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Epoch: [{}]'.format(epoch)
print_freq = kwargs.get('print_freq', 10)
ema = kwargs.get('ema', None)
scaler = kwargs.get('scaler', None)
for samples, targets in metric_logger.log_every(data_loader, print_freq, header):
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
if scaler is not None:
with torch.autocast(device_type=str(device), cache_enabled=True):
outputs = model(samples, targets)
with torch.autocast(device_type=str(device), enabled=False):
loss_dict = criterion(outputs, targets)
loss = sum(loss_dict.values())
scaler.scale(loss).backward()
if max_norm > 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
else:
outputs = model(samples, targets)
loss_dict = criterion(outputs, targets)
loss = sum(loss_dict.values())
optimizer.zero_grad()
loss.backward()
if max_norm > 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
optimizer.step()
# ema
if ema is not None:
ema.update(model)
loss_dict_reduced = reduce_dict(loss_dict)
loss_value = sum(loss_dict_reduced.values())
if not math.isfinite(loss_value):
print("Loss is {}, stopping training".format(loss_value))
print(loss_dict_reduced)
sys.exit(1)
metric_logger.update(loss=loss_value, **loss_dict_reduced)
metric_logger.update(lr=optimizer.param_groups[0]["lr"])
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
return {k: meter.global_avg for k, meter in metric_logger.meters.items()}
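# Hedged call sketch (DetSolver.fit below does the real call): the optional
# kwargs are print_freq (log interval), ema (a ModelEMA instance) and scaler
# (an AMP GradScaler); with scaler=None the loop runs in full precision.
def _train_one_epoch_example(model, criterion, dataloader, optimizer, device, epoch,
                             ema=None, scaler=None):
    return train_one_epoch(model, criterion, dataloader, optimizer, device, epoch,
                           max_norm=0.1, print_freq=10, ema=ema, scaler=scaler)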
@torch.no_grad()
def evaluate(model: torch.nn.Module, criterion: torch.nn.Module, postprocessors, data_loader, base_ds, device, output_dir):
model.eval()
criterion.eval()
metric_logger = MetricLogger(delimiter=" ")
# metric_logger.add_meter('class_error', SmoothedValue(window_size=1, fmt='{value:.2f}'))
header = 'Test:'
# iou_types = tuple(k for k in ('segm', 'bbox') if k in postprocessors.keys())
iou_types = postprocessors.iou_types
coco_evaluator = CocoEvaluator(base_ds, iou_types)
# coco_evaluator.coco_eval[iou_types[0]].params.iouThrs = [0, 0.1, 0.5, 0.75]
panoptic_evaluator = None
# if 'panoptic' in postprocessors.keys():
# panoptic_evaluator = PanopticEvaluator(
# data_loader.dataset.ann_file,
# data_loader.dataset.ann_folder,
# output_dir=os.path.join(output_dir, "panoptic_eval"),
# )
for samples, targets in metric_logger.log_every(data_loader, 10, header):
samples = samples.to(device)
targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
# with torch.autocast(device_type=str(device)):
# outputs = model(samples)
outputs = model(samples)
# loss_dict = criterion(outputs, targets)
# weight_dict = criterion.weight_dict
# # reduce losses over all GPUs for logging purposes
# loss_dict_reduced = reduce_dict(loss_dict)
# loss_dict_reduced_scaled = {k: v * weight_dict[k]
# for k, v in loss_dict_reduced.items() if k in weight_dict}
# loss_dict_reduced_unscaled = {f'{k}_unscaled': v
# for k, v in loss_dict_reduced.items()}
# metric_logger.update(loss=sum(loss_dict_reduced_scaled.values()),
# **loss_dict_reduced_scaled,
# **loss_dict_reduced_unscaled)
# metric_logger.update(class_error=loss_dict_reduced['class_error'])
orig_target_sizes = torch.stack([t["orig_size"] for t in targets], dim=0)
results = postprocessors(outputs, orig_target_sizes)
# results = postprocessors(outputs, targets)
# if 'segm' in postprocessors.keys():
# target_sizes = torch.stack([t["size"] for t in targets], dim=0)
# results = postprocessors['segm'](results, outputs, orig_target_sizes, target_sizes)
res = {target['image_id'].item(): output for target, output in zip(targets, results)}
if coco_evaluator is not None:
coco_evaluator.update(res)
# if panoptic_evaluator is not None:
# res_pano = postprocessors["panoptic"](outputs, target_sizes, orig_target_sizes)
# for i, target in enumerate(targets):
# image_id = target["image_id"].item()
# file_name = f"{image_id:012d}.png"
# res_pano[i]["image_id"] = image_id
# res_pano[i]["file_name"] = file_name
# panoptic_evaluator.update(res_pano)
# gather the stats from all processes
metric_logger.synchronize_between_processes()
print("Averaged stats:", metric_logger)
if coco_evaluator is not None:
coco_evaluator.synchronize_between_processes()
if panoptic_evaluator is not None:
panoptic_evaluator.synchronize_between_processes()
# accumulate predictions from all images
if coco_evaluator is not None:
coco_evaluator.accumulate()
coco_evaluator.summarize()
# panoptic_res = None
# if panoptic_evaluator is not None:
# panoptic_res = panoptic_evaluator.summarize()
stats = {}
# stats = {k: meter.global_avg for k, meter in metric_logger.meters.items()}
if coco_evaluator is not None:
if 'bbox' in iou_types:
stats['coco_eval_bbox'] = coco_evaluator.coco_eval['bbox'].stats.tolist()
if 'segm' in iou_types:
stats['coco_eval_masks'] = coco_evaluator.coco_eval['segm'].stats.tolist()
# if panoptic_res is not None:
# stats['PQ_all'] = panoptic_res["All"]
# stats['PQ_th'] = panoptic_res["Things"]
# stats['PQ_st'] = panoptic_res["Stuff"]
return stats, coco_evaluator
'''
by lyuwenyu
'''
import time
import json
import datetime
import torch
from src.misc import dist
from src.data import get_coco_api_from_dataset
from .solver import BaseSolver
from .det_engine import train_one_epoch, evaluate
class DetSolver(BaseSolver):
def fit(self, ):
print("Start training")
self.train()
args = self.cfg
n_parameters = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
print('number of params:', n_parameters)
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
# best_stat = {'coco_eval_bbox': 0, 'coco_eval_masks': 0, 'epoch': -1, }
best_stat = {'epoch': -1, }
start_time = time.time()
for epoch in range(self.last_epoch + 1, args.epoches):
if dist.is_dist_available_and_initialized():
self.train_dataloader.sampler.set_epoch(epoch)
train_stats = train_one_epoch(
self.model, self.criterion, self.train_dataloader, self.optimizer, self.device, epoch,
args.clip_max_norm, print_freq=args.log_step, ema=self.ema, scaler=self.scaler)
self.lr_scheduler.step()
if self.output_dir:
checkpoint_paths = [self.output_dir / 'checkpoint.pth']
# extra checkpoint before LR drop and every 100 epochs
if (epoch + 1) % args.checkpoint_step == 0:
checkpoint_paths.append(self.output_dir / f'checkpoint{epoch:04}.pth')
for checkpoint_path in checkpoint_paths:
dist.save_on_master(self.state_dict(epoch), checkpoint_path)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(
module, self.criterion, self.postprocessor, self.val_dataloader, base_ds, self.device, self.output_dir
)
# TODO
for k in test_stats.keys():
if k in best_stat:
best_stat['epoch'] = epoch if test_stats[k][0] > best_stat[k] else best_stat['epoch']
best_stat[k] = max(best_stat[k], test_stats[k][0])
else:
best_stat['epoch'] = epoch
best_stat[k] = test_stats[k][0]
print('best_stat: ', best_stat)
log_stats = {**{f'train_{k}': v for k, v in train_stats.items()},
**{f'test_{k}': v for k, v in test_stats.items()},
'epoch': epoch,
'n_parameters': n_parameters}
if self.output_dir and dist.is_main_process():
with (self.output_dir / "log.txt").open("a") as f:
f.write(json.dumps(log_stats) + "\n")
# for evaluation logs
if coco_evaluator is not None:
(self.output_dir / 'eval').mkdir(exist_ok=True)
if "bbox" in coco_evaluator.coco_eval:
filenames = ['latest.pth']
if epoch % 50 == 0:
filenames.append(f'{epoch:03}.pth')
for name in filenames:
torch.save(coco_evaluator.coco_eval["bbox"].eval,
self.output_dir / "eval" / name)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
print('Training time {}'.format(total_time_str))
def val(self, ):
self.eval()
base_ds = get_coco_api_from_dataset(self.val_dataloader.dataset)
module = self.ema.module if self.ema else self.model
test_stats, coco_evaluator = evaluate(module, self.criterion, self.postprocessor,
self.val_dataloader, base_ds, self.device, self.output_dir)
if self.output_dir:
dist.save_on_master(coco_evaluator.coco_eval["bbox"].eval, self.output_dir / "eval.pth")
return
"""by lyuwenyu
"""
import torch
import torch.nn as nn
from datetime import datetime
from pathlib import Path
from typing import Dict
from src.misc import dist
from src.core import BaseConfig
class BaseSolver(object):
def __init__(self, cfg: BaseConfig) -> None:
self.cfg = cfg
def setup(self, ):
'''Avoid instantiating unnecessary classes
'''
cfg = self.cfg
device = cfg.device
self.device = device
self.last_epoch = cfg.last_epoch
self.model = dist.warp_model(cfg.model.to(device), cfg.find_unused_parameters, cfg.sync_bn)
self.criterion = cfg.criterion.to(device)
self.postprocessor = cfg.postprocessor
# NOTE (lvwenyu): should load_tuning_state before ema instance building
if self.cfg.tuning:
print(f'Tuning checkpoint from {self.cfg.tuning}')
self.load_tuning_state(self.cfg.tuning)
self.scaler = cfg.scaler
self.ema = cfg.ema.to(device) if cfg.ema is not None else None
self.output_dir = Path(cfg.output_dir)
self.output_dir.mkdir(parents=True, exist_ok=True)
def train(self, ):
self.setup()
self.optimizer = self.cfg.optimizer
self.lr_scheduler = self.cfg.lr_scheduler
# NOTE instantiating order
if self.cfg.resume:
print(f'Resume checkpoint from {self.cfg.resume}')
self.resume(self.cfg.resume)
self.train_dataloader = dist.warp_loader(self.cfg.train_dataloader, \
shuffle=self.cfg.train_dataloader.shuffle)
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
def eval(self, ):
self.setup()
self.val_dataloader = dist.warp_loader(self.cfg.val_dataloader, \
shuffle=self.cfg.val_dataloader.shuffle)
if self.cfg.resume:
print(f'resume from {self.cfg.resume}')
self.resume(self.cfg.resume)
def state_dict(self, last_epoch):
'''state dict
'''
state = {}
state['model'] = dist.de_parallel(self.model).state_dict()
state['date'] = datetime.now().isoformat()
# TODO
state['last_epoch'] = last_epoch
if self.optimizer is not None:
state['optimizer'] = self.optimizer.state_dict()
if self.lr_scheduler is not None:
state['lr_scheduler'] = self.lr_scheduler.state_dict()
# state['last_epoch'] = self.lr_scheduler.last_epoch
if self.ema is not None:
state['ema'] = self.ema.state_dict()
if self.scaler is not None:
state['scaler'] = self.scaler.state_dict()
return state
def load_state_dict(self, state):
'''load state dict
'''
# TODO
if getattr(self, 'last_epoch', None) and 'last_epoch' in state:
self.last_epoch = state['last_epoch']
print('Loading last_epoch')
if getattr(self, 'model', None) and 'model' in state:
if dist.is_parallel(self.model):
self.model.module.load_state_dict(state['model'])
else:
self.model.load_state_dict(state['model'])
print('Loading model.state_dict')
if getattr(self, 'ema', None) and 'ema' in state:
self.ema.load_state_dict(state['ema'])
print('Loading ema.state_dict')
if getattr(self, 'optimizer', None) and 'optimizer' in state:
self.optimizer.load_state_dict(state['optimizer'])
print('Loading optimizer.state_dict')
if getattr(self, 'lr_scheduler', None) and 'lr_scheduler' in state:
self.lr_scheduler.load_state_dict(state['lr_scheduler'])
print('Loading lr_scheduler.state_dict')
if getattr(self, 'scaler', None) and 'scaler' in state:
self.scaler.load_state_dict(state['scaler'])
print('Loading scaler.state_dict')
def save(self, path):
'''save state
'''
state = self.state_dict(self.last_epoch)  # state_dict() requires the epoch being checkpointed
dist.save_on_master(state, path)
def resume(self, path):
'''load resume
'''
# map to CPU first so resuming does not pile every tensor onto cuda:0
state = torch.load(path, map_location='cpu')
self.load_state_dict(state)
def load_tuning_state(self, path,):
"""only load model for tuning and skip missed/dismatched keys
"""
if 'http' in path:
state = torch.hub.load_state_dict_from_url(path, map_location='cpu')
else:
state = torch.load(path, map_location='cpu')
module = dist.de_parallel(self.model)
# TODO hard code
if 'ema' in state:
stat, infos = self._matched_state(module.state_dict(), state['ema']['module'])
else:
stat, infos = self._matched_state(module.state_dict(), state['model'])
module.load_state_dict(stat, strict=False)
print(f'Load model.state_dict, {infos}')
@staticmethod
def _matched_state(state: Dict[str, torch.Tensor], params: Dict[str, torch.Tensor]):
missed_list = []
unmatched_list = []
matched_state = {}
for k, v in state.items():
if k in params:
if v.shape == params[k].shape:
matched_state[k] = params[k]
else:
unmatched_list.append(k)
else:
missed_list.append(k)
return matched_state, {'missed': missed_list, 'unmatched': unmatched_list}
def fit(self, ):
raise NotImplementedError('')
def val(self, ):
raise NotImplementedError('')
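# Hedged illustration (not part of the original file) of what _matched_state
# returns: only keys present in both dicts with identical shapes are kept, so
# load_state_dict(..., strict=False) can ingest a partially compatible checkpoint.
def _matched_state_demo():
    current = {'head.weight': torch.zeros(10, 3), 'backbone.conv': torch.zeros(8)}
    ckpt = {'head.weight': torch.zeros(20, 3), 'backbone.conv': torch.zeros(8)}
    matched, infos = BaseSolver._matched_state(current, ckpt)
    # matched keeps only 'backbone.conv'; infos == {'missed': [], 'unmatched': ['head.weight']}
    return matched, infos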
"""by lyuwenyu
"""
from .rtdetr import *
from .hybrid_encoder import *
from .rtdetr_decoder import *
from .rtdetr_postprocessor import *
from .rtdetr_criterion import *
from .matcher import *