Commit b6c19984 authored by dengjb

update
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
# Based on: https://github.com/facebookresearch/detectron2/blob/master/detectron2/solver/build.py
import copy
import itertools
import math
import re
from enum import Enum
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Type, Union
import torch
from fastreid.config import CfgNode
from fastreid.utils.params import ContiguousParams
from . import lr_scheduler
_GradientClipperInput = Union[torch.Tensor, Iterable[torch.Tensor]]
_GradientClipper = Callable[[_GradientClipperInput], None]
class GradientClipType(Enum):
VALUE = "value"
NORM = "norm"
def _create_gradient_clipper(cfg: CfgNode) -> _GradientClipper:
"""
    Creates a gradient clipping closure to clip by value or by norm,
according to the provided config.
"""
cfg = copy.deepcopy(cfg)
def clip_grad_norm(p: _GradientClipperInput):
torch.nn.utils.clip_grad_norm_(p, cfg.CLIP_VALUE, cfg.NORM_TYPE)
def clip_grad_value(p: _GradientClipperInput):
torch.nn.utils.clip_grad_value_(p, cfg.CLIP_VALUE)
_GRADIENT_CLIP_TYPE_TO_CLIPPER = {
GradientClipType.VALUE: clip_grad_value,
GradientClipType.NORM: clip_grad_norm,
}
return _GRADIENT_CLIP_TYPE_TO_CLIPPER[GradientClipType(cfg.CLIP_TYPE)]
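# Usage sketch (hypothetical config values): given a clipping config such as
# CLIP_TYPE="norm", CLIP_VALUE=5.0, NORM_TYPE=2.0, the returned closure clips
# gradients in place:
#
#   clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
#   loss.backward()
#   clipper(model.parameters())  # total gradient norm is now at most 5.0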
def _generate_optimizer_class_with_gradient_clipping(
optimizer: Type[torch.optim.Optimizer],
*,
per_param_clipper: Optional[_GradientClipper] = None,
global_clipper: Optional[_GradientClipper] = None,
) -> Type[torch.optim.Optimizer]:
"""
Dynamically creates a new type that inherits the type of a given instance
and overrides the `step` method to add gradient clipping
"""
assert (
per_param_clipper is None or global_clipper is None
), "Not allowed to use both per-parameter clipping and global clipping"
@torch.no_grad()
def optimizer_wgc_step(self, closure=None):
if per_param_clipper is not None:
for group in self.param_groups:
for p in group["params"]:
per_param_clipper(p)
else:
# global clipper for future use with detr
# (https://github.com/facebookresearch/detr/pull/287)
all_params = itertools.chain(*[g["params"] for g in self.param_groups])
global_clipper(all_params)
optimizer.step(self, closure)
OptimizerWithGradientClip = type(
optimizer.__name__ + "WithGradientClip",
(optimizer,),
{"step": optimizer_wgc_step},
)
return OptimizerWithGradientClip
def maybe_add_gradient_clipping(
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
"""
If gradient clipping is enabled through config options, wraps the existing
optimizer type to become a new dynamically created class OptimizerWithGradientClip
that inherits the given optimizer and overrides the `step` method to
include gradient clipping.
Args:
cfg: CfgNode, configuration options
optimizer: type. A subclass of torch.optim.Optimizer
Return:
type: either the input `optimizer` (if gradient clipping is disabled), or
a subclass of it with gradient clipping included in the `step` method.
"""
if not cfg.SOLVER.CLIP_GRADIENTS.ENABLED:
return optimizer
if isinstance(optimizer, torch.optim.Optimizer):
optimizer_type = type(optimizer)
else:
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
optimizer_type = optimizer
grad_clipper = _create_gradient_clipper(cfg.SOLVER.CLIP_GRADIENTS)
OptimizerWithGradientClip = _generate_optimizer_class_with_gradient_clipping(
optimizer_type, per_param_clipper=grad_clipper
)
if isinstance(optimizer, torch.optim.Optimizer):
optimizer.__class__ = OptimizerWithGradientClip # a bit hacky, not recommended
return optimizer
else:
return OptimizerWithGradientClip
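# Example (assuming cfg.SOLVER.CLIP_GRADIENTS.ENABLED is True): passing the SGD
# *class* returns a dynamically generated subclass whose `step` clips each
# parameter's gradient before the update; with clipping disabled the class is
# returned unchanged:
#
#   SGDWithClip = maybe_add_gradient_clipping(cfg, torch.optim.SGD)
#   opt = SGDWithClip(model.parameters(), lr=0.1, momentum=0.9)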
def _generate_optimizer_class_with_freeze_layer(
optimizer: Type[torch.optim.Optimizer],
*,
freeze_iters: int = 0,
) -> Type[torch.optim.Optimizer]:
    assert freeze_iters > 0, "freeze_iters must be positive, got {}".format(freeze_iters)
cnt = 0
@torch.no_grad()
def optimizer_wfl_step(self, closure=None):
nonlocal cnt
if cnt < freeze_iters:
cnt += 1
param_ref = []
grad_ref = []
for group in self.param_groups:
if group["freeze_status"] == "freeze":
for p in group["params"]:
if p.grad is not None:
param_ref.append(p)
grad_ref.append(p.grad)
p.grad = None
optimizer.step(self, closure)
for p, g in zip(param_ref, grad_ref):
p.grad = g
else:
optimizer.step(self, closure)
OptimizerWithFreezeLayer = type(
optimizer.__name__ + "WithFreezeLayer",
(optimizer,),
{"step": optimizer_wfl_step},
)
return OptimizerWithFreezeLayer
def maybe_add_freeze_layer(
cfg: CfgNode, optimizer: Type[torch.optim.Optimizer]
) -> Type[torch.optim.Optimizer]:
if len(cfg.MODEL.FREEZE_LAYERS) == 0 or cfg.SOLVER.FREEZE_ITERS <= 0:
return optimizer
if isinstance(optimizer, torch.optim.Optimizer):
optimizer_type = type(optimizer)
else:
assert issubclass(optimizer, torch.optim.Optimizer), optimizer
optimizer_type = optimizer
OptimizerWithFreezeLayer = _generate_optimizer_class_with_freeze_layer(
optimizer_type,
freeze_iters=cfg.SOLVER.FREEZE_ITERS
)
if isinstance(optimizer, torch.optim.Optimizer):
optimizer.__class__ = OptimizerWithFreezeLayer # a bit hacky, not recommended
return optimizer
else:
return OptimizerWithFreezeLayer
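# Note: the generated `step` expects every param group to carry a
# "freeze_status" key ("freeze" or "normal"), which get_default_optimizer_params
# below attaches, and the freeze counter lives in a closure, so it is shared by
# all instances of the generated class. A sketch under those assumptions:
#
#   opt_cls = maybe_add_freeze_layer(cfg, torch.optim.SGD)  # needs FREEZE_ITERS > 0
#   opt = opt_cls(get_default_optimizer_params(model, base_lr=0.1), lr=0.1)
#   opt.step()  # "freeze" groups have their gradients dropped for the first
#               # cfg.SOLVER.FREEZE_ITERS steps, then behave normally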
def build_optimizer(cfg, model, contiguous=True):
params = get_default_optimizer_params(
model,
base_lr=cfg.SOLVER.BASE_LR,
weight_decay=cfg.SOLVER.WEIGHT_DECAY,
weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
heads_lr_factor=cfg.SOLVER.HEADS_LR_FACTOR,
weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
freeze_layers=cfg.MODEL.FREEZE_LAYERS if cfg.SOLVER.FREEZE_ITERS > 0 else [],
)
if contiguous:
params = ContiguousParams(params)
solver_opt = cfg.SOLVER.OPT
if solver_opt == "SGD":
return maybe_add_freeze_layer(
cfg,
maybe_add_gradient_clipping(cfg, torch.optim.SGD)
)(
params.contiguous() if contiguous else params,
momentum=cfg.SOLVER.MOMENTUM,
nesterov=cfg.SOLVER.NESTEROV,
), params
else:
return maybe_add_freeze_layer(
cfg,
maybe_add_gradient_clipping(cfg, getattr(torch.optim, solver_opt))
)(params.contiguous() if contiguous else params), params
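# Usage sketch (hypothetical cfg/model): build_optimizer returns both the
# wrapped optimizer and the parameter container, since the ContiguousParams
# wrapper must be kept alive while training. The buffer check below assumes
# fastreid's ContiguousParams keeps the upstream contiguous_pytorch_params API:
#
#   optimizer, param_wrapper = build_optimizer(cfg, model, contiguous=True)
#   loss.backward()
#   optimizer.step()
#   param_wrapper.assert_buffer_is_valid()  # optional sanity check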
def get_default_optimizer_params(
model: torch.nn.Module,
base_lr: Optional[float] = None,
weight_decay: Optional[float] = None,
weight_decay_norm: Optional[float] = None,
bias_lr_factor: Optional[float] = 1.0,
heads_lr_factor: Optional[float] = 1.0,
weight_decay_bias: Optional[float] = None,
overrides: Optional[Dict[str, Dict[str, float]]] = None,
        freeze_layers: Optional[List[str]] = None,
):
"""
    Get the default parameter list for the optimizer, with support for a few types of
    overrides. If no overrides are needed, this is equivalent to `model.parameters()`.
Args:
base_lr: lr for every group by default. Can be omitted to use the one in optimizer.
weight_decay: weight decay for every group by default. Can be omitted to use the one
in optimizer.
weight_decay_norm: override weight decay for params in normalization layers
bias_lr_factor: multiplier of lr for bias parameters.
heads_lr_factor: multiplier of lr for model.head parameters.
weight_decay_bias: override weight decay for bias parameters
overrides: if not `None`, provides values for optimizer hyperparameters
(LR, weight decay) for module parameters with a given name; e.g.
``{"embedding": {"lr": 0.01, "weight_decay": 0.1}}`` will set the LR and
weight decay values for all module parameters named `embedding`.
freeze_layers: layer names for freezing.
    For common detection models, ``weight_decay_norm`` is the only option
    that needs to be set. ``bias_lr_factor`` and ``weight_decay_bias`` are
    legacy settings from Detectron1 that were not found useful.
Example:
::
torch.optim.SGD(get_default_optimizer_params(model, weight_decay_norm=0),
lr=0.01, weight_decay=1e-4, momentum=0.9)
"""
    if overrides is None:
        overrides = {}
    if freeze_layers is None:
        freeze_layers = []
defaults = {}
if base_lr is not None:
defaults["lr"] = base_lr
if weight_decay is not None:
defaults["weight_decay"] = weight_decay
bias_overrides = {}
if bias_lr_factor is not None and bias_lr_factor != 1.0:
# NOTE: unlike Detectron v1, we now by default make bias hyperparameters
# exactly the same as regular weights.
if base_lr is None:
raise ValueError("bias_lr_factor requires base_lr")
bias_overrides["lr"] = base_lr * bias_lr_factor
if weight_decay_bias is not None:
bias_overrides["weight_decay"] = weight_decay_bias
if len(bias_overrides):
if "bias" in overrides:
raise ValueError("Conflicting overrides for 'bias'")
overrides["bias"] = bias_overrides
layer_names_pattern = [re.compile(name) for name in freeze_layers]
norm_module_types = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
# NaiveSyncBatchNorm inherits from BatchNorm2d
torch.nn.GroupNorm,
torch.nn.InstanceNorm1d,
torch.nn.InstanceNorm2d,
torch.nn.InstanceNorm3d,
torch.nn.LayerNorm,
torch.nn.LocalResponseNorm,
)
params: List[Dict[str, Any]] = []
memo: Set[torch.nn.parameter.Parameter] = set()
for module_name, module in model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
if not value.requires_grad:
continue
# Avoid duplicating parameters
if value in memo:
continue
memo.add(value)
hyperparams = copy.copy(defaults)
if isinstance(module, norm_module_types) and weight_decay_norm is not None:
hyperparams["weight_decay"] = weight_decay_norm
hyperparams.update(overrides.get(module_param_name, {}))
if module_name.split('.')[0] == "heads" and (heads_lr_factor is not None and heads_lr_factor != 1.0):
hyperparams["lr"] = hyperparams.get("lr", base_lr) * heads_lr_factor
name = module_name + '.' + module_param_name
freeze_status = "normal"
            # Freeze-layer patterns must match from the beginning of the name, so use `match`, not `search`
for pattern in layer_names_pattern:
if pattern.match(name) is not None:
freeze_status = "freeze"
break
params.append({"freeze_status": freeze_status, "params": [value], **hyperparams})
return params
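# Illustration (hypothetical values): each returned entry is a one-parameter
# group, so downstream wrappers can act per parameter, e.g.
#
#   [{"freeze_status": "freeze", "params": [backbone_w], "lr": 0.1, "weight_decay": 5e-4},
#    {"freeze_status": "normal", "params": [heads_w], "lr": 0.1 * heads_lr_factor,
#     "weight_decay": 5e-4},
#    ...]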
def build_lr_scheduler(cfg, optimizer, iters_per_epoch):
max_epoch = cfg.SOLVER.MAX_EPOCH - max(
math.ceil(cfg.SOLVER.WARMUP_ITERS / iters_per_epoch), cfg.SOLVER.DELAY_EPOCHS)
scheduler_dict = {}
scheduler_args = {
"MultiStepLR": {
"optimizer": optimizer,
# multi-step lr scheduler options
"milestones": cfg.SOLVER.STEPS,
"gamma": cfg.SOLVER.GAMMA,
},
"CosineAnnealingLR": {
"optimizer": optimizer,
# cosine annealing lr scheduler options
"T_max": max_epoch,
"eta_min": cfg.SOLVER.ETA_MIN_LR,
},
}
scheduler_dict["lr_sched"] = getattr(lr_scheduler, cfg.SOLVER.SCHED)(
**scheduler_args[cfg.SOLVER.SCHED])
if cfg.SOLVER.WARMUP_ITERS > 0:
warmup_args = {
"optimizer": optimizer,
# warmup options
"warmup_factor": cfg.SOLVER.WARMUP_FACTOR,
"warmup_iters": cfg.SOLVER.WARMUP_ITERS,
"warmup_method": cfg.SOLVER.WARMUP_METHOD,
}
scheduler_dict["warmup_sched"] = lr_scheduler.WarmupLR(**warmup_args)
return scheduler_dict
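# Usage sketch (hypothetical training loop): the warmup scheduler, if present,
# is stepped per iteration and the main scheduler per epoch, roughly:
#
#   sched = build_lr_scheduler(cfg, optimizer, iters_per_epoch)
#   for epoch in range(cfg.SOLVER.MAX_EPOCH):
#       for _ in range(iters_per_epoch):
#           ...  # forward / backward / optimizer.step()
#           if global_iter < cfg.SOLVER.WARMUP_ITERS and "warmup_sched" in sched:
#               sched["warmup_sched"].step()
#       sched["lr_sched"].step()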
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from typing import List
import torch
from torch.optim.lr_scheduler import *
class WarmupLR(torch.optim.lr_scheduler._LRScheduler):
def __init__(
self,
optimizer: torch.optim.Optimizer,
warmup_factor: float = 0.1,
warmup_iters: int = 1000,
warmup_method: str = "linear",
last_epoch: int = -1,
):
self.warmup_factor = warmup_factor
self.warmup_iters = warmup_iters
self.warmup_method = warmup_method
super().__init__(optimizer, last_epoch)
def get_lr(self) -> List[float]:
warmup_factor = _get_warmup_factor_at_epoch(
self.warmup_method, self.last_epoch, self.warmup_iters, self.warmup_factor
)
return [
base_lr * warmup_factor for base_lr in self.base_lrs
]
def _compute_values(self) -> List[float]:
# The new interface
return self.get_lr()
def _get_warmup_factor_at_epoch(
method: str, iter: int, warmup_iters: int, warmup_factor: float
) -> float:
"""
Return the learning rate warmup factor at a specific iteration.
See https://arxiv.org/abs/1706.02677 for more details.
    Args:
        method (str): warmup method; one of "constant", "linear", or "exp".
        iter (int): iteration at which to calculate the warmup factor.
        warmup_iters (int): the number of warmup iterations.
        warmup_factor (float): the base warmup factor (the meaning changes according
            to the method used).
Returns:
float: the effective warmup factor at the given iteration.
"""
if iter >= warmup_iters:
return 1.0
if method == "constant":
return warmup_factor
elif method == "linear":
alpha = iter / warmup_iters
return warmup_factor * (1 - alpha) + alpha
elif method == "exp":
return warmup_factor ** (1 - iter / warmup_iters)
else:
raise ValueError("Unknown warmup method: {}".format(method))
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
from .lamb import Lamb
from .swa import SWA
from .radam import RAdam
from torch.optim import *
####
# CODE TAKEN FROM https://github.com/mgrankin/over9000
####
import collections
import torch
from torch.optim.optimizer import Optimizer
from torch.utils.tensorboard import SummaryWriter
def log_lamb_rs(optimizer: Optimizer, event_writer: SummaryWriter, token_count: int):
"""Log a histogram of trust ratio scalars in across layers."""
results = collections.defaultdict(list)
for group in optimizer.param_groups:
for p in group['params']:
state = optimizer.state[p]
for i in ('weight_norm', 'adam_norm', 'trust_ratio'):
if i in state:
results[i].append(state[i])
for k, v in results.items():
event_writer.add_histogram(f'lamb/{k}', torch.tensor(v), token_count)
class Lamb(Optimizer):
r"""Implements Lamb algorithm.
It has been proposed in `Large Batch Optimization for Deep Learning: Training BERT in 76 minutes`_.
Arguments:
params (iterable): iterable of parameters to optimize or dicts defining
parameter groups
lr (float, optional): learning rate (default: 1e-3)
betas (Tuple[float, float], optional): coefficients used for computing
running averages of gradient and its square (default: (0.9, 0.999))
eps (float, optional): term added to the denominator to improve
numerical stability (default: 1e-8)
weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
adam (bool, optional): always use trust ratio = 1, which turns this into
Adam. Useful for comparison purposes.
.. _Large Batch Optimization for Deep Learning: Training BERT in 76 minutes:
https://arxiv.org/abs/1904.00962
"""
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-6,
weight_decay=0, adam=False):
if not 0.0 <= lr:
raise ValueError("Invalid learning rate: {}".format(lr))
if not 0.0 <= eps:
raise ValueError("Invalid epsilon value: {}".format(eps))
if not 0.0 <= betas[0] < 1.0:
raise ValueError("Invalid beta parameter at index 0: {}".format(betas[0]))
if not 0.0 <= betas[1] < 1.0:
raise ValueError("Invalid beta parameter at index 1: {}".format(betas[1]))
defaults = dict(lr=lr, betas=betas, eps=eps,
weight_decay=weight_decay)
self.adam = adam
super(Lamb, self).__init__(params, defaults)
def step(self, closure=None):
"""Performs a single optimization step.
Arguments:
closure (callable, optional): A closure that reevaluates the model
and returns the loss.
"""
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data
if grad.is_sparse:
                    raise RuntimeError('Lamb does not support sparse gradients, consider SparseAdam instead.')
state = self.state[p]
# State initialization
if len(state) == 0:
state['step'] = 0
# Exponential moving average of gradient values
state['exp_avg'] = torch.zeros_like(p.data)
# Exponential moving average of squared gradient values
state['exp_avg_sq'] = torch.zeros_like(p.data)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
state['step'] += 1
                # Decay the first and second moment running average coefficient
                # m_t
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
                # v_t
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
# Paper v3 does not use debiasing.
# bias_correction1 = 1 - beta1 ** state['step']
# bias_correction2 = 1 - beta2 ** state['step']
# Apply bias to lr to avoid broadcast.
step_size = group['lr'] # * math.sqrt(bias_correction2) / bias_correction1
weight_norm = p.data.pow(2).sum().sqrt().clamp(0, 10)
adam_step = exp_avg / exp_avg_sq.sqrt().add(group['eps'])
                if group['weight_decay'] != 0:
                    adam_step.add_(p.data, alpha=group['weight_decay'])
adam_norm = adam_step.pow(2).sum().sqrt()
if weight_norm == 0 or adam_norm == 0:
trust_ratio = 1
else:
trust_ratio = weight_norm / adam_norm
state['weight_norm'] = weight_norm
state['adam_norm'] = adam_norm
state['trust_ratio'] = trust_ratio
if self.adam:
trust_ratio = 1
                p.data.add_(adam_step, alpha=-step_size * trust_ratio)
return loss
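# Numeric sketch of the trust ratio: if a layer's weight norm is 1.2 and the
# norm of its (weight-decayed) Adam step is 0.3, the update is scaled by
# trust_ratio = 1.2 / 0.3 = 4.0, so layers with large weights take
# proportionally larger steps. With adam=True the ratio is pinned to 1 and
# Lamb reduces to Adam without bias correction (see the paper v3 note above).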
import math
import torch
from torch.optim.optimizer import Optimizer
class RAdam(Optimizer):
def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
self.buffer = [[None, None, None] for ind in range(10)]
super(RAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(RAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
buffered = self.buffer[int(state['step'] % 10)]
if state['step'] == buffered[0]:
N_sma, step_size = buffered[1], buffered[2]
else:
buffered[0] = state['step']
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
buffered[1] = N_sma
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
buffered[2] = step_size
if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
# more conservative since it's an approximated value
if N_sma >= 5:
denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)
return loss
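# Note on the rectification term: N_sma estimates the effective length of the
# SMA underlying the adaptive learning rate. For beta2 = 0.999, N_sma_max =
# 2 / (1 - 0.999) - 1 = 1999; early in training N_sma < 5, so the step above
# falls back to an unadapted, SGD-with-momentum-like update, which is exactly
# the RAdam warmup behavior.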
class PlainRAdam(Optimizer):
def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
super(PlainRAdam, self).__init__(params, defaults)
def __setstate__(self, state):
super(PlainRAdam, self).__setstate__(state)
def step(self, closure=None):
loss = None
if closure is not None:
loss = closure()
for group in self.param_groups:
for p in group['params']:
if p.grad is None:
continue
grad = p.grad.data.float()
if grad.is_sparse:
raise RuntimeError('RAdam does not support sparse gradients')
p_data_fp32 = p.data.float()
state = self.state[p]
if len(state) == 0:
state['step'] = 0
state['exp_avg'] = torch.zeros_like(p_data_fp32)
state['exp_avg_sq'] = torch.zeros_like(p_data_fp32)
else:
state['exp_avg'] = state['exp_avg'].type_as(p_data_fp32)
state['exp_avg_sq'] = state['exp_avg_sq'].type_as(p_data_fp32)
exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
beta1, beta2 = group['betas']
                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
                exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
state['step'] += 1
beta2_t = beta2 ** state['step']
N_sma_max = 2 / (1 - beta2) - 1
N_sma = N_sma_max - 2 * state['step'] * beta2_t / (1 - beta2_t)
if group['weight_decay'] != 0:
                    p_data_fp32.add_(p_data_fp32, alpha=-group['weight_decay'] * group['lr'])
# more conservative since it's an approximated value
if N_sma >= 5:
step_size = group['lr'] * math.sqrt(
(1 - beta2_t) * (N_sma - 4) / (N_sma_max - 4) * (N_sma - 2) / N_sma * N_sma_max / (
N_sma_max - 2)) / (1 - beta1 ** state['step'])
denom = exp_avg_sq.sqrt().add_(group['eps'])
                    p_data_fp32.addcdiv_(exp_avg, denom, value=-step_size)
else:
step_size = group['lr'] / (1 - beta1 ** state['step'])
                    p_data_fp32.add_(exp_avg, alpha=-step_size)
p.data.copy_(p_data_fp32)
return loss
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on:
# https://github.com/pytorch/contrib/blob/master/torchcontrib/optim/swa.py
import warnings
from collections import defaultdict
import torch
from torch.optim.optimizer import Optimizer
class SWA(Optimizer):
def __init__(self, optimizer, swa_freq=None, swa_lr_factor=None):
r"""Implements Stochastic Weight Averaging (SWA).
Stochastic Weight Averaging was proposed in `Averaging Weights Leads to
Wider Optima and Better Generalization`_ by Pavel Izmailov, Dmitrii
Podoprikhin, Timur Garipov, Dmitry Vetrov and Andrew Gordon Wilson
(UAI 2018).
SWA is implemented as a wrapper class taking optimizer instance as input
and applying SWA on top of that optimizer.
        SWA can be used in two modes: automatic and manual. In automatic
        mode, the SWA running averages are updated every :attr:`swa_freq`
        optimization steps. To use SWA in automatic mode, provide a value
        for the :attr:`swa_freq` argument.
        Alternatively, in manual mode, use the :meth:`update_swa` or
        :meth:`update_swa_group` methods to update the SWA running averages.
        At the end of training, call :meth:`swap_swa_param` to set the
        optimized variables to the computed averages.
        Args:
            swa_freq (int): number of steps between subsequent updates of
                SWA running averages in automatic mode; if None, manual mode
                is selected (default: None)
            swa_lr_factor (float): factor by which :meth:`reset_lr_to_swa`
                scales the current learning rate; if None, the learning rate
                is not changed (default: None)
Examples:
>>> # automatic mode
>>> base_opt = torch.optim.SGD(model.parameters(), lr=0.1)
>>> opt = SWA(base_opt, swa_start=10, swa_freq=5, swa_lr=0.05)
>>> for _ in range(100):
>>> opt.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> opt.step()
>>> opt.swap_swa_param()
>>> # manual mode
>>> opt = SWA(base_opt)
>>> for i in range(100):
>>> opt.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> opt.step()
>>> if i > 10 and i % 5 == 0:
>>> opt.update_swa()
>>> opt.swap_swa_param()
        .. note::
            SWA does not support parameter-specific values of :attr:`swa_freq`
            or :attr:`swa_lr_factor`. In automatic mode SWA uses the same
            :attr:`swa_freq` for all parameter groups. If needed, use manual
            mode with :meth:`update_swa_group` to use different update
            schedules for different parameter groups.
        .. note::
            Call :meth:`swap_swa_param` at the end of training to use the
            computed running averages.
.. note::
If you are using SWA to optimize the parameters of a Neural Network
containing Batch Normalization layers, you need to update the
:attr:`running_mean` and :attr:`running_var` statistics of the
Batch Normalization module. You can do so by using
`torchcontrib.optim.swa.bn_update` utility.
.. note::
See the blogpost
https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/
for an extended description of this SWA implementation.
.. note::
The repo https://github.com/izmailovpavel/contrib_swa_examples
contains examples of using this SWA implementation.
.. _Averaging Weights Leads to Wider Optima and Better Generalization:
https://arxiv.org/abs/1803.05407
.. _Improving Consistency-Based Semi-Supervised Learning with Weight
Averaging:
https://arxiv.org/abs/1806.05594
"""
self._auto_mode, (self.swa_freq,) = self._check_params(swa_freq)
self.swa_lr_factor = swa_lr_factor
if self._auto_mode:
if swa_freq < 1:
raise ValueError("Invalid swa_freq: {}".format(swa_freq))
else:
if self.swa_lr_factor is not None:
                warnings.warn(
                    "swa_freq is None, ignoring swa_lr_factor")
# If not in auto mode make all swa parameters None
self.swa_lr_factor = None
self.swa_freq = None
if self.swa_lr_factor is not None and self.swa_lr_factor < 0:
raise ValueError("Invalid SWA learning rate factor: {}".format(swa_lr_factor))
self.optimizer = optimizer
self.defaults = self.optimizer.defaults
self.param_groups = self.optimizer.param_groups
self.state = defaultdict(dict)
self.opt_state = self.optimizer.state
for group in self.param_groups:
group['n_avg'] = 0
group['step_counter'] = 0
@staticmethod
def _check_params(swa_freq):
params = [swa_freq]
params_none = [param is None for param in params]
if not all(params_none) and any(params_none):
            warnings.warn(
                "Some SWA parameters are None, ignoring the others")
for i, param in enumerate(params):
if param is not None and not isinstance(param, int):
params[i] = int(param)
warnings.warn("Casting swa_start, swa_freq to int")
return not any(params_none), params
def reset_lr_to_swa(self):
for param_group in self.param_groups:
param_group['initial_lr'] = self.swa_lr_factor * param_group['lr']
def update_swa_group(self, group):
r"""Updates the SWA running averages for the given parameter group.
Arguments:
group (dict): Specifies for what parameter group SWA running
averages should be updated
Examples:
>>> # automatic mode
>>> base_opt = torch.optim.SGD([{'params': [x]},
>>> {'params': [y], 'lr': 1e-3}], lr=1e-2, momentum=0.9)
>>> opt = torchcontrib.optim.SWA(base_opt)
>>> for i in range(100):
>>> opt.zero_grad()
>>> loss_fn(model(input), target).backward()
>>> opt.step()
>>> if i > 10 and i % 5 == 0:
>>> # Update SWA for the second parameter group
>>> opt.update_swa_group(opt.param_groups[1])
>>> opt.swap_swa_param()
"""
for p in group['params']:
param_state = self.state[p]
if 'swa_buffer' not in param_state:
param_state['swa_buffer'] = torch.zeros_like(p.data)
buf = param_state['swa_buffer']
virtual_decay = 1 / float(group["n_avg"] + 1)
diff = (p.data - buf) * virtual_decay
buf.add_(diff)
group["n_avg"] += 1
def update_swa(self):
r"""Updates the SWA running averages of all optimized parameters.
"""
for group in self.param_groups:
self.update_swa_group(group)
def swap_swa_param(self):
r"""Swaps the values of the optimized variables and swa buffers.
        It is meant to be called at the end of training to use the collected
        SWA running averages. It can also be used to evaluate the running
        averages during training; to continue training, `swap_swa_param`
        should be called again.
"""
for group in self.param_groups:
for p in group['params']:
param_state = self.state[p]
if 'swa_buffer' not in param_state:
# If swa wasn't applied we don't swap params
warnings.warn(
"SWA wasn't applied to param {}; skipping it".format(p))
continue
buf = param_state['swa_buffer']
tmp = torch.empty_like(p.data)
tmp.copy_(p.data)
p.data.copy_(buf)
buf.copy_(tmp)
def step(self, closure=None):
r"""Performs a single optimization step.
In automatic mode also updates SWA running averages.
"""
loss = self.optimizer.step(closure)
for group in self.param_groups:
group["step_counter"] += 1
steps = group["step_counter"]
if self._auto_mode:
if steps % self.swa_freq == 0:
self.update_swa_group(group)
return loss
def state_dict(self):
r"""Returns the state of SWA as a :class:`dict`.
It contains three entries:
* opt_state - a dict holding current optimization state of the base
optimizer. Its content differs between optimizer classes.
* swa_state - a dict containing current state of SWA. For each
optimized variable it contains swa_buffer keeping the running
average of the variable
* param_groups - a dict containing all parameter groups
"""
opt_state_dict = self.optimizer.state_dict()
swa_state = {(id(k) if isinstance(k, torch.Tensor) else k): v
for k, v in self.state.items()}
opt_state = opt_state_dict["state"]
param_groups = opt_state_dict["param_groups"]
return {"opt_state": opt_state, "swa_state": swa_state,
"param_groups": param_groups}
def load_state_dict(self, state_dict):
r"""Loads the optimizer state.
Args:
state_dict (dict): SWA optimizer state. Should be an object returned
from a call to `state_dict`.
"""
swa_state_dict = {"state": state_dict["swa_state"],
"param_groups": state_dict["param_groups"]}
opt_state_dict = {"state": state_dict["opt_state"],
"param_groups": state_dict["param_groups"]}
super(SWA, self).load_state_dict(swa_state_dict)
self.optimizer.load_state_dict(opt_state_dict)
self.opt_state = self.optimizer.state
def add_param_group(self, param_group):
r"""Add a param group to the :class:`Optimizer` s `param_groups`.
This can be useful when fine tuning a pre-trained network as frozen
layers can be made trainable and added to the :class:`Optimizer` as
training progresses.
Args:
param_group (dict): Specifies what Tensors should be optimized along
with group specific optimization options.
"""
param_group['n_avg'] = 0
param_group['step_counter'] = 0
self.optimizer.add_param_group(param_group)
# encoding: utf-8
"""
@author: sherlock
@contact: sherlockliao01@gmail.com
"""
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import copy
import logging
import os
from collections import defaultdict
from typing import Any
from typing import Optional, List, Dict, NamedTuple, Tuple, Iterable
import numpy as np
import torch
import torch.nn as nn
from termcolor import colored
from torch.nn.parallel import DistributedDataParallel, DataParallel
from fastreid.utils.file_io import PathManager
class _IncompatibleKeys(
NamedTuple(
# pyre-fixme[10]: Name `IncompatibleKeys` is used but not defined.
"IncompatibleKeys",
[
("missing_keys", List[str]),
("unexpected_keys", List[str]),
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
("incorrect_shapes", List[Tuple]),
],
)
):
pass
class Checkpointer(object):
"""
A checkpointer that can save/load model as well as extra checkpointable
objects.
"""
def __init__(
self,
model: nn.Module,
save_dir: str = "",
*,
save_to_disk: bool = True,
**checkpointables: object,
):
"""
Args:
model (nn.Module): model.
save_dir (str): a directory to save and find checkpoints.
save_to_disk (bool): if True, save checkpoint to disk, otherwise
disable saving for this checkpointer.
checkpointables (object): any checkpointable objects, i.e., objects
that have the `state_dict()` and `load_state_dict()` method. For
example, it can be used like
`Checkpointer(model, "dir", optimizer=optimizer)`.
"""
if isinstance(model, (DistributedDataParallel, DataParallel)):
model = model.module
self.model = model
self.checkpointables = copy.copy(checkpointables)
self.logger = logging.getLogger(__name__)
self.save_dir = save_dir
self.save_to_disk = save_to_disk
self.path_manager = PathManager
    def save(self, name: str, **kwargs: Any):
"""
Dump model and checkpointables to a file.
Args:
name (str): name of the file.
kwargs (dict): extra arbitrary data to save.
"""
if not self.save_dir or not self.save_to_disk:
return
data = {}
data["model"] = self.model.state_dict()
for key, obj in self.checkpointables.items():
data[key] = obj.state_dict()
data.update(kwargs)
basename = "{}.pth".format(name)
save_file = os.path.join(self.save_dir, basename)
assert os.path.basename(save_file) == basename, basename
self.logger.info("Saving checkpoint to {}".format(save_file))
with PathManager.open(save_file, "wb") as f:
torch.save(data, f)
self.tag_last_checkpoint(basename)
def load(self, path: str, checkpointables: Optional[List[str]] = None) -> object:
"""
        Load from the given checkpoint. When path points to a network file, this
function has to be called on all ranks.
Args:
path (str): path or url to the checkpoint. If empty, will not load
anything.
checkpointables (list): List of checkpointable names to load. If not
specified (None), will load all the possible checkpointables.
Returns:
dict:
extra data loaded from the checkpoint that has not been
processed. For example, those saved with
:meth:`.save(**extra_data)`.
"""
if not path:
# no checkpoint provided
self.logger.info("No checkpoint found. Training model from scratch")
return {}
self.logger.info("Loading checkpoint from {}".format(path))
if not os.path.isfile(path):
path = self.path_manager.get_local_path(path)
assert os.path.isfile(path), "Checkpoint {} not found!".format(path)
checkpoint = self._load_file(path)
incompatible = self._load_model(checkpoint)
if (
incompatible is not None
        ):  # handle some existing subclasses that return None
self._log_incompatible_keys(incompatible)
for key in self.checkpointables if checkpointables is None else checkpointables:
if key in checkpoint: # pyre-ignore
self.logger.info("Loading {} from {}".format(key, path))
obj = self.checkpointables[key]
obj.load_state_dict(checkpoint.pop(key)) # pyre-ignore
# return any further checkpoint data
return checkpoint
def has_checkpoint(self):
"""
Returns:
bool: whether a checkpoint exists in the target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
return PathManager.exists(save_file)
def get_checkpoint_file(self):
"""
Returns:
str: The latest checkpoint file in target directory.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
try:
with PathManager.open(save_file, "r") as f:
last_saved = f.read().strip()
except IOError:
# if file doesn't exist, maybe because it has just been
# deleted by a separate process
return ""
return os.path.join(self.save_dir, last_saved)
def get_all_checkpoint_files(self):
"""
Returns:
list: All available checkpoint files (.pth files) in target
directory.
"""
all_model_checkpoints = [
os.path.join(self.save_dir, file)
for file in PathManager.ls(self.save_dir)
if PathManager.isfile(os.path.join(self.save_dir, file))
and file.endswith(".pth")
]
return all_model_checkpoints
def resume_or_load(self, path: str, *, resume: bool = True):
"""
If `resume` is True, this method attempts to resume from the last
checkpoint, if exists. Otherwise, load checkpoint from the given path.
This is useful when restarting an interrupted training job.
Args:
path (str): path to the checkpoint.
resume (bool): if True, resume from the last checkpoint if it exists.
Returns:
same as :meth:`load`.
"""
if resume and self.has_checkpoint():
path = self.get_checkpoint_file()
return self.load(path)
else:
return self.load(path, checkpointables=[])
def tag_last_checkpoint(self, last_filename_basename: str):
"""
Tag the last checkpoint.
Args:
last_filename_basename (str): the basename of the last filename.
"""
save_file = os.path.join(self.save_dir, "last_checkpoint")
with PathManager.open(save_file, "w") as f:
f.write(last_filename_basename)
def _load_file(self, f: str):
"""
Load a checkpoint file. Can be overwritten by subclasses to support
different formats.
Args:
f (str): a locally mounted file path.
        Returns:
            dict: with keys "model" and optionally others that are saved by
                the checkpointer. dict["model"] must be a dict which maps
                strings to torch.Tensor or numpy arrays.
"""
return torch.load(f, map_location=torch.device("cpu"))
def _load_model(self, checkpoint: Any):
"""
Load weights from a checkpoint.
Args:
checkpoint (Any): checkpoint contains the weights.
"""
checkpoint_state_dict = checkpoint.pop("model")
self._convert_ndarray_to_tensor(checkpoint_state_dict)
# if the state_dict comes from a model that was wrapped in a
# DataParallel or DistributedDataParallel during serialization,
# remove the "module" prefix before performing the matching.
_strip_prefix_if_present(checkpoint_state_dict, "module.")
# work around https://github.com/pytorch/pytorch/issues/24139
model_state_dict = self.model.state_dict()
incorrect_shapes = []
for k in list(checkpoint_state_dict.keys()):
if k in model_state_dict:
shape_model = tuple(model_state_dict[k].shape)
shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
if shape_model != shape_checkpoint:
incorrect_shapes.append((k, shape_checkpoint, shape_model))
checkpoint_state_dict.pop(k)
incompatible = self.model.load_state_dict(checkpoint_state_dict, strict=False)
return _IncompatibleKeys(
missing_keys=incompatible.missing_keys,
unexpected_keys=incompatible.unexpected_keys,
incorrect_shapes=incorrect_shapes,
)
def _log_incompatible_keys(self, incompatible: _IncompatibleKeys) -> None:
"""
Log information about the incompatible keys returned by ``_load_model``.
"""
for k, shape_checkpoint, shape_model in incompatible.incorrect_shapes:
self.logger.warning(
"Skip loading parameter '{}' to the model due to incompatible "
"shapes: {} in the checkpoint but {} in the "
"model! You might want to double check if this is expected.".format(
k, shape_checkpoint, shape_model
)
)
if incompatible.missing_keys:
missing_keys = _filter_reused_missing_keys(
self.model, incompatible.missing_keys
)
if missing_keys:
self.logger.info(get_missing_parameters_message(missing_keys))
if incompatible.unexpected_keys:
self.logger.info(
get_unexpected_parameters_message(incompatible.unexpected_keys)
)
def _convert_ndarray_to_tensor(self, state_dict: dict):
"""
In-place convert all numpy arrays in the state_dict to torch tensor.
Args:
state_dict (dict): a state-dict to be loaded to the model.
"""
# model could be an OrderedDict with _metadata attribute
# (as returned by Pytorch's state_dict()). We should preserve these
# properties.
for k in list(state_dict.keys()):
v = state_dict[k]
if not isinstance(v, np.ndarray) and not isinstance(
v, torch.Tensor
):
raise ValueError(
"Unsupported type found in checkpoint! {}: {}".format(
k, type(v)
)
)
if not isinstance(v, torch.Tensor):
state_dict[k] = torch.from_numpy(v)
class PeriodicCheckpointer:
"""
    Save checkpoints periodically. When `.step(epoch)` is called, it will
    execute `checkpointer.save` on the given checkpointer, if the epoch is a
    multiple of period or if `max_epoch` is reached.
"""
    def __init__(self, checkpointer: Any, period: int, max_epoch: Optional[int] = None):
"""
Args:
checkpointer (Any): the checkpointer object used to save
checkpoints.
period (int): the period to save checkpoint.
max_epoch (int): maximum number of epochs. When it is reached,
a checkpoint named "model_final" will be saved.
"""
self.checkpointer = checkpointer
self.period = int(period)
self.max_epoch = max_epoch
self.best_metric = -1
def step(self, epoch: int, **kwargs: Any):
"""
        Perform the appropriate action at the given epoch.
Args:
epoch (int): the current epoch, ranged in [0, max_epoch-1].
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
epoch = int(epoch)
additional_state = {"epoch": epoch}
additional_state.update(kwargs)
if (epoch + 1) % self.period == 0 and epoch < self.max_epoch - 1:
if additional_state["metric"] > self.best_metric:
self.checkpointer.save(
"model_best", **additional_state
)
self.best_metric = additional_state["metric"]
# Put it behind best model save to make last checkpoint valid
self.checkpointer.save(
"model_{:04d}".format(epoch), **additional_state
)
if epoch >= self.max_epoch - 1:
if additional_state["metric"] > self.best_metric:
self.checkpointer.save(
"model_best", **additional_state
)
self.checkpointer.save("model_final", **additional_state)
def save(self, name: str, **kwargs: Any):
"""
Same argument as :meth:`Checkpointer.save`.
Use this method to manually save checkpoints outside the schedule.
Args:
name (str): file name.
kwargs (Any): extra data to save, same as in
:meth:`Checkpointer.save`.
"""
self.checkpointer.save(name, **kwargs)
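# Usage sketch (hypothetical names): PeriodicCheckpointer.step reads a "metric"
# entry from the extra state to track the best model, so pass one in:
#
#   checkpointer = Checkpointer(model, cfg.OUTPUT_DIR, optimizer=optimizer)
#   periodic = PeriodicCheckpointer(checkpointer, period=10, max_epoch=120)
#   for epoch in range(120):
#       ...  # train one epoch, evaluate to obtain `metric`
#       periodic.step(epoch, metric=metric)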
def _filter_reused_missing_keys(model: nn.Module, keys: List[str]) -> List[str]:
"""
Filter "missing keys" to not include keys that have been loaded with another name.
"""
keyset = set(keys)
    param_to_names = defaultdict(set)  # param -> names that point to it
for module_prefix, module in _named_modules_with_dup(model):
for name, param in list(module.named_parameters(recurse=False)) + list(
module.named_buffers(recurse=False) # pyre-ignore
):
full_name = (module_prefix + "." if module_prefix else "") + name
param_to_names[param].add(full_name)
for names in param_to_names.values():
# if one name appears missing but its alias exists, then this
# name is not considered missing
if any(n in keyset for n in names) and not all(n in keyset for n in names):
            for n in names:
                if n in keyset:
                    keyset.remove(n)
return list(keyset)
def get_missing_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the model but not found in a checkpoint.
Args:
keys (list[str]): List of keys that were not found in the checkpoint.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "Some model parameters or buffers are not found in the checkpoint:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "blue") for k, v in groups.items()
)
return msg
def get_unexpected_parameters_message(keys: List[str]) -> str:
"""
Get a logging-friendly message to report parameter names (keys) that are in
the checkpoint but not found in the model.
Args:
keys (list[str]): List of keys that were not found in the model.
Returns:
str: message.
"""
groups = _group_checkpoint_keys(keys)
msg = "The checkpoint state_dict contains keys that are not used by the model:\n"
msg += "\n".join(
" " + colored(k + _group_to_str(v), "magenta") for k, v in groups.items()
)
return msg
def _strip_prefix_if_present(state_dict: Dict[str, Any], prefix: str) -> None:
"""
Strip the prefix in metadata, if any.
Args:
state_dict (OrderedDict): a state-dict to be loaded to the model.
prefix (str): prefix.
"""
keys = sorted(state_dict.keys())
if not all(len(key) == 0 or key.startswith(prefix) for key in keys):
return
for key in keys:
newkey = key[len(prefix):]
state_dict[newkey] = state_dict.pop(key)
    # also strip the prefix in metadata, if any.
try:
metadata = state_dict._metadata # pyre-ignore
except AttributeError:
pass
else:
for key in list(metadata.keys()):
# for the metadata dict, the key can be:
# '': for the DDP module, which we want to remove.
# 'module': for the actual model.
# 'module.xx.xx': for the rest.
if len(key) == 0:
continue
newkey = key[len(prefix):]
metadata[newkey] = metadata.pop(key)
def _group_checkpoint_keys(keys: List[str]) -> Dict[str, List[str]]:
"""
Group keys based on common prefixes. A prefix is the string up to the final
"." in each key.
Args:
keys (list[str]): list of parameter names, i.e. keys in the model
checkpoint dict.
Returns:
dict[list]: keys with common prefixes are grouped into lists.
"""
groups = defaultdict(list)
for key in keys:
pos = key.rfind(".")
if pos >= 0:
head, tail = key[:pos], [key[pos + 1:]]
else:
head, tail = key, []
groups[head].extend(tail)
return groups
def _group_to_str(group: List[str]) -> str:
"""
Format a group of parameter name suffixes into a loggable string.
Args:
group (list[str]): list of parameter name suffixes.
Returns:
        str: formatted string.
"""
if len(group) == 0:
return ""
if len(group) == 1:
return "." + group[0]
return ".{" + ", ".join(group) + "}"
def _named_modules_with_dup(
model: nn.Module, prefix: str = ""
) -> Iterable[Tuple[str, nn.Module]]:
"""
The same as `model.named_modules()`, except that it includes
duplicated modules that have more than one name.
"""
yield prefix, model
for name, module in model._modules.items(): # pyre-ignore
if module is None:
continue
submodule_prefix = prefix + ("." if prefix else "") + name
yield from _named_modules_with_dup(module, submodule_prefix)
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# based on
# https://github.com/facebookresearch/detectron2/blob/master/detectron2/utils/collect_env.py
import importlib
import os
import re
import subprocess
import sys
from collections import defaultdict
import PIL
import numpy as np
import torch
import torchvision
from tabulate import tabulate
__all__ = ["collect_env_info"]
def collect_torch_env():
try:
import torch.__config__
return torch.__config__.show()
except ImportError:
# compatible with older versions of pytorch
from torch.utils.collect_env import get_pretty_env_info
return get_pretty_env_info()
def get_env_module():
var_name = "FASTREID_ENV_MODULE"
return var_name, os.environ.get(var_name, "<not set>")
def detect_compute_compatibility(CUDA_HOME, so_file):
try:
cuobjdump = os.path.join(CUDA_HOME, "bin", "cuobjdump")
if os.path.isfile(cuobjdump):
output = subprocess.check_output(
"'{}' --list-elf '{}'".format(cuobjdump, so_file), shell=True
)
output = output.decode("utf-8").strip().split("\n")
sm = []
for line in output:
line = re.findall(r"\.sm_[0-9]*\.", line)[0]
sm.append(line.strip("."))
sm = sorted(set(sm))
return ", ".join(sm)
else:
return so_file + "; cannot find cuobjdump"
except Exception:
# unhandled failure
return so_file
def collect_env_info():
has_gpu = torch.cuda.is_available() # true for both CUDA & ROCM
torch_version = torch.__version__
# NOTE: the use of CUDA_HOME and ROCM_HOME requires the CUDA/ROCM build deps, though in
# theory detectron2 should be made runnable with only the corresponding runtimes
from torch.utils.cpp_extension import CUDA_HOME
has_rocm = False
if tuple(map(int, torch_version.split(".")[:2])) >= (1, 5):
from torch.utils.cpp_extension import ROCM_HOME
if (getattr(torch.version, "hip", None) is not None) and (ROCM_HOME is not None):
has_rocm = True
has_cuda = has_gpu and (not has_rocm)
data = []
data.append(("sys.platform", sys.platform))
data.append(("Python", sys.version.replace("\n", "")))
data.append(("numpy", np.__version__))
try:
import fastreid # noqa
data.append(
("fastreid", fastreid.__version__ + " @" + os.path.dirname(fastreid.__file__))
)
except ImportError:
data.append(("fastreid", "failed to import"))
data.append(get_env_module())
data.append(("PyTorch", torch_version + " @" + os.path.dirname(torch.__file__)))
data.append(("PyTorch debug build", torch.version.debug))
data.append(("GPU available", has_gpu))
if has_gpu:
devices = defaultdict(list)
for k in range(torch.cuda.device_count()):
devices[torch.cuda.get_device_name(k)].append(str(k))
for name, devids in devices.items():
data.append(("GPU " + ",".join(devids), name))
if has_rocm:
data.append(("ROCM_HOME", str(ROCM_HOME)))
else:
data.append(("CUDA_HOME", str(CUDA_HOME)))
cuda_arch_list = os.environ.get("TORCH_CUDA_ARCH_LIST", None)
if cuda_arch_list:
data.append(("TORCH_CUDA_ARCH_LIST", cuda_arch_list))
data.append(("Pillow", PIL.__version__))
try:
data.append(
(
"torchvision",
str(torchvision.__version__) + " @" + os.path.dirname(torchvision.__file__),
)
)
if has_cuda:
try:
torchvision_C = importlib.util.find_spec("torchvision._C").origin
msg = detect_compute_compatibility(CUDA_HOME, torchvision_C)
data.append(("torchvision arch flags", msg))
except ImportError:
data.append(("torchvision._C", "failed to find"))
except AttributeError:
data.append(("torchvision", "unknown"))
try:
import fvcore
data.append(("fvcore", fvcore.__version__))
except ImportError:
pass
try:
import cv2
data.append(("cv2", cv2.__version__))
except ImportError:
pass
env_str = tabulate(data) + "\n"
env_str += collect_torch_env()
return env_str
if __name__ == "__main__":
try:
import detectron2 # noqa
except ImportError:
print(collect_env_info())
else:
from fastreid.utils.collect_env import collect_env_info
print(collect_env_info())
"""
This file contains primitives for multi-gpu communication.
This is useful when doing distributed training.
"""
import functools
import logging
import numpy as np
import pickle
import torch
import torch.distributed as dist
_LOCAL_PROCESS_GROUP = None
"""
A torch process group which only includes processes that are on the same machine as the current process.
This variable is set when processes are spawned by `launch()` in "engine/launch.py".
"""
def get_world_size() -> int:
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size()
def get_rank() -> int:
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
return dist.get_rank()
def get_local_rank() -> int:
"""
Returns:
The rank of the current process within the local (per-machine) process group.
"""
if not dist.is_available():
return 0
if not dist.is_initialized():
return 0
assert _LOCAL_PROCESS_GROUP is not None
return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
def get_local_size() -> int:
"""
Returns:
The size of the per-machine process group,
i.e. the number of processes per machine.
"""
if not dist.is_available():
return 1
if not dist.is_initialized():
return 1
return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
def is_main_process() -> bool:
return get_rank() == 0
def synchronize():
"""
Helper function to synchronize (barrier) among all processes when
using distributed training
"""
if not dist.is_available():
return
if not dist.is_initialized():
return
world_size = dist.get_world_size()
if world_size == 1:
return
dist.barrier()
@functools.lru_cache()
def _get_global_gloo_group():
"""
Return a process group based on gloo backend, containing all the ranks
The result is cached.
"""
if dist.get_backend() == "nccl":
return dist.new_group(backend="gloo")
else:
return dist.group.WORLD
def _serialize_to_tensor(data, group):
backend = dist.get_backend(group)
assert backend in ["gloo", "nccl"]
device = torch.device("cpu" if backend == "gloo" else "cuda")
buffer = pickle.dumps(data)
if len(buffer) > 1024 ** 3:
logger = logging.getLogger(__name__)
logger.warning(
"Rank {} trying to all-gather {:.2f} GB of data on device {}".format(
get_rank(), len(buffer) / (1024 ** 3), device
)
)
storage = torch.ByteStorage.from_buffer(buffer)
tensor = torch.ByteTensor(storage).to(device=device)
return tensor
def _pad_to_largest_tensor(tensor, group):
"""
Returns:
list[int]: size of the tensor, on each rank
Tensor: padded tensor that has the max size
"""
world_size = dist.get_world_size(group=group)
assert (
world_size >= 1
), "comm.gather/all_gather must be called from ranks within the given group!"
local_size = torch.tensor([tensor.numel()], dtype=torch.int64, device=tensor.device)
size_list = [
torch.zeros([1], dtype=torch.int64, device=tensor.device) for _ in range(world_size)
]
dist.all_gather(size_list, local_size, group=group)
size_list = [int(size.item()) for size in size_list]
max_size = max(size_list)
# we pad the tensor because torch all_gather does not support
# gathering tensors of different shapes
if local_size != max_size:
padding = torch.zeros((max_size - local_size,), dtype=torch.uint8, device=tensor.device)
tensor = torch.cat((tensor, padding), dim=0)
return size_list, tensor
def all_gather(data, group=None):
"""
Run all_gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: list of data gathered from each rank
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group) == 1:
return [data]
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
max_size = max(size_list)
# receiving Tensor from all ranks
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.all_gather(tensor_list, tensor, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
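# Usage sketch: the payload is pickled, so arbitrary picklable objects work,
# e.g. gathering per-rank statistics (hypothetical values):
#
#   stats = {"rank": get_rank(), "num_images": len(dataset)}
#   all_stats = all_gather(stats)  # list with one dict per rank, on every rank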
def gather(data, dst=0, group=None):
"""
Run gather on arbitrary picklable data (not necessarily tensors).
Args:
data: any picklable object
dst (int): destination rank
group: a torch process group. By default, will use a group which
contains all ranks on gloo backend.
Returns:
list[data]: on dst, a list of data gathered from each rank. Otherwise,
an empty list.
"""
if get_world_size() == 1:
return [data]
if group is None:
group = _get_global_gloo_group()
if dist.get_world_size(group=group) == 1:
return [data]
rank = dist.get_rank(group=group)
tensor = _serialize_to_tensor(data, group)
size_list, tensor = _pad_to_largest_tensor(tensor, group)
# receiving Tensor from all ranks
if rank == dst:
max_size = max(size_list)
tensor_list = [
torch.empty((max_size,), dtype=torch.uint8, device=tensor.device) for _ in size_list
]
dist.gather(tensor, tensor_list, dst=dst, group=group)
data_list = []
for size, tensor in zip(size_list, tensor_list):
buffer = tensor.cpu().numpy().tobytes()[:size]
data_list.append(pickle.loads(buffer))
return data_list
else:
dist.gather(tensor, [], dst=dst, group=group)
return []
def shared_random_seed():
"""
Returns:
int: a random number that is the same across all workers.
If workers need a shared RNG, they can use this shared seed to
create one.
All workers must call this function, otherwise it will deadlock.
"""
ints = np.random.randint(2 ** 31)
all_ints = all_gather(ints)
return all_ints[0]
def reduce_dict(input_dict, average=True):
"""
Reduce the values in the dictionary from all processes so that process with rank
0 has the reduced results.
Args:
input_dict (dict): inputs to be reduced. All the values must be scalar CUDA Tensor.
average (bool): whether to do average or sum
Returns:
a dict with the same keys as input_dict, after reduction.
"""
world_size = get_world_size()
if world_size < 2:
return input_dict
with torch.no_grad():
names = []
values = []
# sort the keys so that they are consistent across processes
for k in sorted(input_dict.keys()):
names.append(k)
values.append(input_dict[k])
values = torch.stack(values, dim=0)
dist.reduce(values, dst=0)
if dist.get_rank() == 0 and average:
# only main process gets accumulated, so only divide by
# world_size in this case
values /= world_size
reduced_dict = {k: v for k, v in zip(names, values)}
return reduced_dict
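# Usage sketch (hypothetical loss dict of scalar CUDA tensors): typically used
# to log losses averaged over workers from the main process only:
#
#   loss_dict = {"cls_loss": cls_loss.detach(), "triplet_loss": tri_loss.detach()}
#   reduced = reduce_dict(loss_dict)  # averaged; meaningful on rank 0
#   if is_main_process():
#       total_loss = sum(reduced.values())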
# encoding: utf-8
"""
@author: xingyu liao
@contact: sherlockliao01@gmail.com
"""
# Modified from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/compute_dist.py
import faiss
import numpy as np
import torch
import torch.nn.functional as F
from .faiss_utils import (
index_init_cpu,
index_init_gpu,
search_index_pytorch,
search_raw_array_pytorch,
)
__all__ = [
"build_dist",
"compute_jaccard_distance",
"compute_euclidean_distance",
"compute_cosine_distance",
]
@torch.no_grad()
def build_dist(feat_1: torch.Tensor, feat_2: torch.Tensor, metric: str = "euclidean", **kwargs) -> np.ndarray:
r"""Compute distance between two feature embeddings.
Args:
feat_1 (torch.Tensor): 2-D feature with batch dimension.
feat_2 (torch.Tensor): 2-D feature with batch dimension.
        metric (str): one of "euclidean", "cosine" or "jaccard" (default: "euclidean").
Returns:
numpy.ndarray: distance matrix.
"""
assert metric in ["cosine", "euclidean", "jaccard"], "Expected metrics are cosine, euclidean and jaccard, " \
"but got {}".format(metric)
if metric == "euclidean":
return compute_euclidean_distance(feat_1, feat_2)
elif metric == "cosine":
return compute_cosine_distance(feat_1, feat_2)
elif metric == "jaccard":
feat = torch.cat((feat_1, feat_2), dim=0)
dist = compute_jaccard_distance(feat, k1=kwargs["k1"], k2=kwargs["k2"], search_option=0)
return dist[: feat_1.size(0), feat_1.size(0):]
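# Usage sketch (hypothetical query/gallery tensors): the jaccard metric needs
# k1/k2 passed through kwargs; the others take only the two feature matrices:
#
#   dist = build_dist(query_feat, gallery_feat, metric="cosine")
#   rerank = build_dist(query_feat, gallery_feat, metric="jaccard", k1=20, k2=6)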
def k_reciprocal_neigh(initial_rank, i, k1):
forward_k_neigh_index = initial_rank[i, : k1 + 1]
backward_k_neigh_index = initial_rank[forward_k_neigh_index, : k1 + 1]
fi = np.where(backward_k_neigh_index == i)[0]
return forward_k_neigh_index[fi]
@torch.no_grad()
def compute_jaccard_distance(features, k1=20, k2=6, search_option=0, fp16=False):
if search_option < 3:
# torch.cuda.empty_cache()
features = features.cuda()
ngpus = faiss.get_num_gpus()
N = features.size(0)
mat_type = np.float16 if fp16 else np.float32
if search_option == 0:
# GPU + PyTorch CUDA Tensors (1)
res = faiss.StandardGpuResources()
res.setDefaultNullStreamAllDevices()
_, initial_rank = search_raw_array_pytorch(res, features, features, k1)
initial_rank = initial_rank.cpu().numpy()
elif search_option == 1:
# GPU + PyTorch CUDA Tensors (2)
res = faiss.StandardGpuResources()
index = faiss.GpuIndexFlatL2(res, features.size(-1))
index.add(features.cpu().numpy())
_, initial_rank = search_index_pytorch(index, features, k1)
res.syncDefaultStreamCurrentDevice()
initial_rank = initial_rank.cpu().numpy()
elif search_option == 2:
# GPU
index = index_init_gpu(ngpus, features.size(-1))
index.add(features.cpu().numpy())
_, initial_rank = index.search(features.cpu().numpy(), k1)
else:
# CPU
index = index_init_cpu(features.size(-1))
index.add(features.cpu().numpy())
_, initial_rank = index.search(features.cpu().numpy(), k1)
nn_k1 = []
nn_k1_half = []
for i in range(N):
nn_k1.append(k_reciprocal_neigh(initial_rank, i, k1))
nn_k1_half.append(k_reciprocal_neigh(initial_rank, i, int(np.around(k1 / 2))))
V = np.zeros((N, N), dtype=mat_type)
for i in range(N):
k_reciprocal_index = nn_k1[i]
k_reciprocal_expansion_index = k_reciprocal_index
for candidate in k_reciprocal_index:
candidate_k_reciprocal_index = nn_k1_half[candidate]
if len(
np.intersect1d(candidate_k_reciprocal_index, k_reciprocal_index)
) > 2 / 3 * len(candidate_k_reciprocal_index):
k_reciprocal_expansion_index = np.append(
k_reciprocal_expansion_index, candidate_k_reciprocal_index
)
k_reciprocal_expansion_index = np.unique(
k_reciprocal_expansion_index
) # element-wise unique
x = features[i].unsqueeze(0).contiguous()
y = features[k_reciprocal_expansion_index]
m, n = x.size(0), y.size(0)
dist = (
torch.pow(x, 2).sum(dim=1, keepdim=True).expand(m, n)
+ torch.pow(y, 2).sum(dim=1, keepdim=True).expand(n, m).t()
)
dist.addmm_(x, y.t(), beta=1, alpha=-2)
if fp16:
V[i, k_reciprocal_expansion_index] = (
F.softmax(-dist, dim=1).view(-1).cpu().numpy().astype(mat_type)
)
else:
V[i, k_reciprocal_expansion_index] = (
F.softmax(-dist, dim=1).view(-1).cpu().numpy()
)
del nn_k1, nn_k1_half, x, y
features = features.cpu()
if k2 != 1:
V_qe = np.zeros_like(V, dtype=mat_type)
for i in range(N):
V_qe[i, :] = np.mean(V[initial_rank[i, :k2], :], axis=0)
V = V_qe
del V_qe
del initial_rank
invIndex = []
for i in range(N):
invIndex.append(np.where(V[:, i] != 0)[0]) # len(invIndex)=all_num
jaccard_dist = np.zeros((N, N), dtype=mat_type)
for i in range(N):
temp_min = np.zeros((1, N), dtype=mat_type)
indNonZero = np.where(V[i, :] != 0)[0]
indImages = [invIndex[ind] for ind in indNonZero]
for j in range(len(indNonZero)):
temp_min[0, indImages[j]] = temp_min[0, indImages[j]] + np.minimum(
V[i, indNonZero[j]], V[indImages[j], indNonZero[j]]
)
jaccard_dist[i] = 1 - temp_min / (2 - temp_min)
del invIndex, V
    neg_mask = jaccard_dist < 0
    jaccard_dist[neg_mask] = 0.0
return jaccard_dist
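# A quick reference sketch for the search_option flag above: 0 and 1 search
# with faiss GPU resources on CUDA tensors, 2 uses a GPU index on numpy
# copies, and any other value falls back to a CPU index. A CPU-only call,
# assuming `features` is a hypothetical 2-D float tensor:
def _demo_jaccard_cpu(features):
    return compute_jaccard_distance(features, k1=20, k2=6, search_option=3)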
@torch.no_grad()
def compute_euclidean_distance(features, others):
    m, n = features.size(0), others.size(0)
    # squared euclidean distance: ||x||^2 + ||y||^2 - 2 * x.y
    dist_m = (
        torch.pow(features, 2).sum(dim=1, keepdim=True).expand(m, n)
        + torch.pow(others, 2).sum(dim=1, keepdim=True).expand(n, m).t()
    )
    dist_m.addmm_(features, others.t(), beta=1, alpha=-2)
return dist_m.cpu().numpy()
@torch.no_grad()
def compute_cosine_distance(features, others):
"""Computes cosine distance.
Args:
features (torch.Tensor): 2-D feature matrix.
others (torch.Tensor): 2-D feature matrix.
Returns:
        numpy.ndarray: distance matrix.
"""
features = F.normalize(features, p=2, dim=1)
others = F.normalize(others, p=2, dim=1)
dist_m = 1 - torch.mm(features, others.t())
return dist_m.cpu().numpy()
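# A small sanity-check sketch: for L2-normalized rows the squared euclidean
# distance equals twice the cosine distance, since ||x - y||^2 = 2 - 2 * x.y:
def _demo_distance_relation():
    x = F.normalize(torch.randn(4, 8), p=2, dim=1)
    y = F.normalize(torch.randn(5, 8), p=2, dim=1)
    assert np.allclose(compute_euclidean_distance(x, y),
                       2 * compute_cosine_distance(x, y), atol=1e-5)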
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import importlib
import importlib.util
import logging
import numpy as np
import os
import random
import sys
from datetime import datetime
import torch
__all__ = ["seed_all_rng"]
TORCH_VERSION = tuple(int(x) for x in torch.__version__.split(".")[:2])
"""
PyTorch version as a tuple of 2 ints. Useful for comparison.
"""
def seed_all_rng(seed=None):
"""
Set the random seed for the RNG in torch, numpy and python.
Args:
seed (int): if None, will use a strong random seed.
"""
if seed is None:
seed = (
os.getpid()
+ int(datetime.now().strftime("%S%f"))
+ int.from_bytes(os.urandom(2), "big")
)
logger = logging.getLogger(__name__)
logger.info("Using a generated random seed {}".format(seed))
np.random.seed(seed)
torch.set_rng_state(torch.manual_seed(seed).get_state())
random.seed(seed)
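# A minimal usage sketch: offset a base seed by a hypothetical worker rank so
# each process gets a distinct but reproducible RNG stream; passing None
# instead derives a strong seed from the pid, time and os.urandom:
def _demo_seed(base_seed=42, rank=0):
    seed_all_rng(base_seed + rank)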
# from https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path
def _import_file(module_name, file_path, make_importable=False):
spec = importlib.util.spec_from_file_location(module_name, file_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
if make_importable:
sys.modules[module_name] = module
return module
def _configure_libraries():
"""
Configurations for some libraries.
"""
# An environment option to disable `import cv2` globally,
# in case it leads to negative performance impact
disable_cv2 = int(os.environ.get("DETECTRON2_DISABLE_CV2", False))
if disable_cv2:
sys.modules["cv2"] = None
else:
# Disable opencl in opencv since its interaction with cuda often has negative effects
# This envvar is supported after OpenCV 3.4.0
os.environ["OPENCV_OPENCL_RUNTIME"] = "disabled"
try:
import cv2
if int(cv2.__version__.split(".")[0]) >= 3:
cv2.ocl.setUseOpenCL(False)
except ImportError:
pass
def get_version(module, digit=2):
return tuple(map(int, module.__version__.split(".")[:digit]))
# fmt: off
assert get_version(torch) >= (1, 4), "Requires torch>=1.4"
import yaml
assert get_version(yaml) >= (5, 1), "Requires pyyaml>=5.1"
# fmt: on
_ENV_SETUP_DONE = False
def setup_environment():
"""Perform environment setup work. The default setup is a no-op, but this
function allows the user to specify a Python source file or a module in
the $FASTREID_ENV_MODULE environment variable, that performs
custom setup work that may be necessary to their computing environment.
"""
global _ENV_SETUP_DONE
if _ENV_SETUP_DONE:
return
_ENV_SETUP_DONE = True
_configure_libraries()
custom_module_path = os.environ.get("FASTREID_ENV_MODULE")
if custom_module_path:
setup_custom_environment(custom_module_path)
else:
# The default setup is a no-op
pass
def setup_custom_environment(custom_module):
"""
Load custom environment setup by importing a Python source file or a
module, and run the setup function.
"""
if custom_module.endswith(".py"):
module = _import_file("fastreid.utils.env.custom_module", custom_module)
else:
module = importlib.import_module(custom_module)
assert hasattr(module, "setup_environment") and callable(module.setup_environment), (
"Custom environment module defined in {} does not have the "
"required callable attribute 'setup_environment'."
).format(custom_module)
module.setup_environment()
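# A minimal sketch of what such a custom module could contain, assuming it is
# referenced by $FASTREID_ENV_MODULE (hypothetical file my_env.py):
#
#     def setup_environment():
#         import os
#         os.environ["OMP_NUM_THREADS"] = "4"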
# Copyright (c) Facebook, Inc. and its affiliates.
import datetime
import json
import logging
import os
import time
from collections import defaultdict
from contextlib import contextmanager
import torch
from .file_io import PathManager
from .history_buffer import HistoryBuffer
__all__ = [
"get_event_storage",
"JSONWriter",
"TensorboardXWriter",
"CommonMetricPrinter",
"EventStorage",
]
_CURRENT_STORAGE_STACK = []
def get_event_storage():
"""
Returns:
The :class:`EventStorage` object that's currently being used.
Throws an error if no :class:`EventStorage` is currently enabled.
"""
assert len(
_CURRENT_STORAGE_STACK
), "get_event_storage() has to be called inside a 'with EventStorage(...)' context!"
return _CURRENT_STORAGE_STACK[-1]
class EventWriter:
"""
Base class for writers that obtain events from :class:`EventStorage` and process them.
"""
def write(self):
raise NotImplementedError
def close(self):
pass
class JSONWriter(EventWriter):
"""
Write scalars to a json file.
It saves scalars as one json per line (instead of a big json) for easy parsing.
    Example of parsing such a json file:
::
$ cat metrics.json | jq -s '.[0:2]'
[
{
"data_time": 0.008433341979980469,
"iteration": 19,
"loss": 1.9228371381759644,
"loss_box_reg": 0.050025828182697296,
"loss_classifier": 0.5316952466964722,
"loss_mask": 0.7236229181289673,
"loss_rpn_box": 0.0856662318110466,
"loss_rpn_cls": 0.48198649287223816,
"lr": 0.007173333333333333,
"time": 0.25401854515075684
},
{
"data_time": 0.007216215133666992,
"iteration": 39,
"loss": 1.282649278640747,
"loss_box_reg": 0.06222952902317047,
"loss_classifier": 0.30682939291000366,
"loss_mask": 0.6970193982124329,
"loss_rpn_box": 0.038663312792778015,
"loss_rpn_cls": 0.1471673548221588,
"lr": 0.007706666666666667,
"time": 0.2490077018737793
}
]
$ cat metrics.json | jq '.loss_mask'
0.7126231789588928
0.689423680305481
0.6776131987571716
...
"""
def __init__(self, json_file, window_size=20):
"""
Args:
json_file (str): path to the json file. New data will be appended if the file exists.
window_size (int): the window size of median smoothing for the scalars whose
`smoothing_hint` are True.
"""
self._file_handle = PathManager.open(json_file, "a")
self._window_size = window_size
self._last_write = -1
def write(self):
storage = get_event_storage()
to_save = defaultdict(dict)
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
# keep scalars that have not been written
if iter <= self._last_write:
continue
to_save[iter][k] = v
if len(to_save):
all_iters = sorted(to_save.keys())
self._last_write = max(all_iters)
for itr, scalars_per_iter in to_save.items():
scalars_per_iter["iteration"] = itr
self._file_handle.write(json.dumps(scalars_per_iter, sort_keys=True) + "\n")
self._file_handle.flush()
try:
os.fsync(self._file_handle.fileno())
except AttributeError:
pass
def close(self):
self._file_handle.close()
class TensorboardXWriter(EventWriter):
"""
Write all scalars to a tensorboard file.
"""
def __init__(self, log_dir: str, window_size: int = 20, **kwargs):
"""
Args:
log_dir (str): the directory to save the output events
window_size (int): the scalars will be median-smoothed by this window size
kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)`
"""
self._window_size = window_size
from torch.utils.tensorboard import SummaryWriter
self._writer = SummaryWriter(log_dir, **kwargs)
self._last_write = -1
def write(self):
storage = get_event_storage()
new_last_write = self._last_write
for k, (v, iter) in storage.latest_with_smoothing_hint(self._window_size).items():
if iter > self._last_write:
self._writer.add_scalar(k, v, iter)
new_last_write = max(new_last_write, iter)
self._last_write = new_last_write
# storage.put_{image,histogram} is only meant to be used by
# tensorboard writer. So we access its internal fields directly from here.
if len(storage._vis_data) >= 1:
for img_name, img, step_num in storage._vis_data:
self._writer.add_image(img_name, img, step_num)
            # Storage stores all image data and relies on this writer to clear them.
# As a result it assumes only one writer will use its image data.
# An alternative design is to let storage store limited recent
# data (e.g. only the most recent image) that all writers can access.
# In that case a writer may not see all image data if its period is long.
storage.clear_images()
if len(storage._histograms) >= 1:
for params in storage._histograms:
self._writer.add_histogram_raw(**params)
storage.clear_histograms()
def close(self):
if hasattr(self, "_writer"): # doesn't exist when the code fails at import
self._writer.close()
class CommonMetricPrinter(EventWriter):
"""
Print **common** metrics to the terminal, including
iteration time, ETA, memory, all losses, and the learning rate.
It also applies smoothing using a window of 20 elements.
It's meant to print common metrics in common ways.
    To print something in a more customized way, please implement a similar printer yourself.
"""
def __init__(self, max_iter):
"""
Args:
max_iter (int): the maximum number of iterations to train.
Used to compute ETA.
"""
self.logger = logging.getLogger(__name__)
self._max_iter = max_iter
self._last_write = None
def write(self):
storage = get_event_storage()
iteration = storage.iter
epoch = storage.epoch
if iteration == self._max_iter:
            # This hook only reports training progress (loss, ETA, etc.) but not other
            # data, so it does not write anything after training succeeds, even if this
            # method is called.
return
try:
data_time = storage.history("data_time").avg(20)
except KeyError:
# they may not exist in the first few iterations (due to warmup)
# or when SimpleTrainer is not used
data_time = None
eta_string = None
try:
iter_time = storage.history("time").global_avg()
eta_seconds = storage.history("time").median(1000) * (self._max_iter - iteration - 1)
storage.put_scalar("eta_seconds", eta_seconds, smoothing_hint=False)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
except KeyError:
iter_time = None
# estimate eta on our own - more noisy
if self._last_write is not None:
estimate_iter_time = (time.perf_counter() - self._last_write[1]) / (
iteration - self._last_write[0]
)
eta_seconds = estimate_iter_time * (self._max_iter - iteration - 1)
eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
self._last_write = (iteration, time.perf_counter())
try:
lr = "{:.2e}".format(storage.history("lr").latest())
except KeyError:
lr = "N/A"
if torch.cuda.is_available():
max_mem_mb = torch.cuda.max_memory_allocated() / 1024.0 / 1024.0
else:
max_mem_mb = None
# NOTE: max_mem is parsed by grep in "dev/parse_results.sh"
self.logger.info(
" {eta}epoch/iter: {epoch}/{iter} {losses} {time}{data_time}lr: {lr} {memory}".format(
eta=f"eta: {eta_string} " if eta_string else "",
epoch=epoch,
iter=iteration,
losses=" ".join(
[
"{}: {:.4g}".format(k, v.median(200))
for k, v in storage.histories().items()
if "loss" in k
]
),
time="time: {:.4f} ".format(iter_time) if iter_time is not None else "",
data_time="data_time: {:.4f} ".format(data_time) if data_time is not None else "",
lr=lr,
memory="max_mem: {:.0f}M".format(max_mem_mb) if max_mem_mb is not None else "",
)
)
class EventStorage:
"""
The user-facing class that provides metric storage functionalities.
In the future we may add support for storing / logging other types of data if needed.
"""
def __init__(self, start_iter=0):
"""
Args:
start_iter (int): the iteration number to start with
"""
self._history = defaultdict(HistoryBuffer)
self._smoothing_hints = {}
self._latest_scalars = {}
self._iter = start_iter
self._current_prefix = ""
self._vis_data = []
self._histograms = []
def put_image(self, img_name, img_tensor):
"""
Add an `img_tensor` associated with `img_name`, to be shown on
tensorboard.
Args:
img_name (str): The name of the image to put into tensorboard.
            img_tensor (torch.Tensor or numpy.array): A `uint8` or `float`
Tensor of shape `[channel, height, width]` where `channel` is
3. The image format should be RGB. The elements in img_tensor
can either have values in [0, 1] (float32) or [0, 255] (uint8).
The `img_tensor` will be visualized in tensorboard.
"""
self._vis_data.append((img_name, img_tensor, self._iter))
def put_scalar(self, name, value, smoothing_hint=True):
"""
Add a scalar `value` to the `HistoryBuffer` associated with `name`.
Args:
smoothing_hint (bool): a 'hint' on whether this scalar is noisy and should be
smoothed when logged. The hint will be accessible through
:meth:`EventStorage.smoothing_hints`. A writer may ignore the hint
and apply custom smoothing rule.
It defaults to True because most scalars we save need to be smoothed to
provide any useful signal.
"""
name = self._current_prefix + name
history = self._history[name]
value = float(value)
history.update(value, self._iter)
self._latest_scalars[name] = (value, self._iter)
existing_hint = self._smoothing_hints.get(name)
if existing_hint is not None:
assert (
existing_hint == smoothing_hint
), "Scalar {} was put with a different smoothing_hint!".format(name)
else:
self._smoothing_hints[name] = smoothing_hint
def put_scalars(self, *, smoothing_hint=True, **kwargs):
"""
Put multiple scalars from keyword arguments.
Examples:
storage.put_scalars(loss=my_loss, accuracy=my_accuracy, smoothing_hint=True)
"""
for k, v in kwargs.items():
self.put_scalar(k, v, smoothing_hint=smoothing_hint)
def put_histogram(self, hist_name, hist_tensor, bins=1000):
"""
Create a histogram from a tensor.
Args:
hist_name (str): The name of the histogram to put into tensorboard.
hist_tensor (torch.Tensor): A Tensor of arbitrary shape to be converted
into a histogram.
bins (int): Number of histogram bins.
"""
ht_min, ht_max = hist_tensor.min().item(), hist_tensor.max().item()
# Create a histogram with PyTorch
hist_counts = torch.histc(hist_tensor, bins=bins)
hist_edges = torch.linspace(start=ht_min, end=ht_max, steps=bins + 1, dtype=torch.float32)
# Parameter for the add_histogram_raw function of SummaryWriter
hist_params = dict(
tag=hist_name,
min=ht_min,
max=ht_max,
num=len(hist_tensor),
sum=float(hist_tensor.sum()),
sum_squares=float(torch.sum(hist_tensor ** 2)),
bucket_limits=hist_edges[1:].tolist(),
bucket_counts=hist_counts.tolist(),
global_step=self._iter,
)
self._histograms.append(hist_params)
def history(self, name):
"""
Returns:
HistoryBuffer: the scalar history for name
"""
ret = self._history.get(name, None)
if ret is None:
raise KeyError("No history metric available for {}!".format(name))
return ret
def histories(self):
"""
Returns:
dict[name -> HistoryBuffer]: the HistoryBuffer for all scalars
"""
return self._history
def latest(self):
"""
Returns:
dict[str -> (float, int)]: mapping from the name of each scalar to the most
                recent value and the iteration number at which it was added.
"""
return self._latest_scalars
def latest_with_smoothing_hint(self, window_size=20):
"""
Similar to :meth:`latest`, but the returned values
are either the un-smoothed original latest value,
or a median of the given window_size,
        depending on whether the smoothing_hint is True.
This provides a default behavior that other writers can use.
"""
result = {}
for k, (v, itr) in self._latest_scalars.items():
result[k] = (
self._history[k].median(window_size) if self._smoothing_hints[k] else v,
itr,
)
return result
def smoothing_hints(self):
"""
Returns:
dict[name -> bool]: the user-provided hint on whether the scalar
is noisy and needs smoothing.
"""
return self._smoothing_hints
def step(self):
"""
        Users should either (1) call this function to increment storage.iter when needed,
        or (2) set `storage.iter` to the correct iteration number before each iteration.
The storage will then be able to associate the new data with an iteration number.
"""
self._iter += 1
@property
def iter(self):
"""
Returns:
int: The current iteration number. When used together with a trainer,
this is ensured to be the same as trainer.iter.
"""
return self._iter
@iter.setter
def iter(self, val):
self._iter = int(val)
@property
def iteration(self):
# for backward compatibility
return self._iter
def __enter__(self):
_CURRENT_STORAGE_STACK.append(self)
return self
def __exit__(self, exc_type, exc_val, exc_tb):
assert _CURRENT_STORAGE_STACK[-1] == self
_CURRENT_STORAGE_STACK.pop()
@contextmanager
def name_scope(self, name):
"""
Yields:
A context within which all the events added to this storage
will be prefixed by the name scope.
"""
old_prefix = self._current_prefix
self._current_prefix = name.rstrip("/") + "/"
yield
self._current_prefix = old_prefix
def clear_images(self):
"""
Delete all the stored images for visualization. This should be called
after images are written to tensorboard.
"""
self._vis_data = []
def clear_histograms(self):
"""
Delete all the stored histograms for visualization.
This should be called after histograms are written to tensorboard.
"""
self._histograms = []
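# A minimal usage sketch wiring EventStorage to a writer, with a hypothetical
# three-iteration loop and output path:
def _demo_event_storage(json_file="./metrics.json"):
    writer = JSONWriter(json_file)
    with EventStorage(start_iter=0) as storage:
        for _ in range(3):
            storage.put_scalar("loss", 0.5)
            writer.write()   # flushes one JSON object per new iteration
            storage.step()
    writer.close()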
# encoding: utf-8
# copy from: https://github.com/open-mmlab/OpenUnReID/blob/66bb2ae0b00575b80fbe8915f4d4f4739cc21206/openunreid/core/utils/faiss_utils.py
import faiss
import torch
def swig_ptr_from_FloatTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.float32
return faiss.cast_integer_to_float_ptr(
x.storage().data_ptr() + x.storage_offset() * 4
)
def swig_ptr_from_LongTensor(x):
assert x.is_contiguous()
assert x.dtype == torch.int64, "dtype=%s" % x.dtype
return faiss.cast_integer_to_long_ptr(
x.storage().data_ptr() + x.storage_offset() * 8
)
def search_index_pytorch(index, x, k, D=None, I=None):
"""call the search function of an index with pytorch tensor I/O (CPU
and GPU supported)"""
assert x.is_contiguous()
n, d = x.size()
assert d == index.d
if D is None:
D = torch.empty((n, k), dtype=torch.float32, device=x.device)
else:
assert D.size() == (n, k)
if I is None:
I = torch.empty((n, k), dtype=torch.int64, device=x.device)
else:
assert I.size() == (n, k)
torch.cuda.synchronize()
xptr = swig_ptr_from_FloatTensor(x)
Iptr = swig_ptr_from_LongTensor(I)
Dptr = swig_ptr_from_FloatTensor(D)
index.search_c(n, xptr, k, Dptr, Iptr)
torch.cuda.synchronize()
return D, I
def search_raw_array_pytorch(res, xb, xq, k, D=None, I=None, metric=faiss.METRIC_L2):
assert xb.device == xq.device
nq, d = xq.size()
if xq.is_contiguous():
xq_row_major = True
elif xq.t().is_contiguous():
xq = xq.t() # I initially wrote xq:t(), Lua is still haunting me :-)
xq_row_major = False
else:
raise TypeError("matrix should be row or column-major")
xq_ptr = swig_ptr_from_FloatTensor(xq)
nb, d2 = xb.size()
assert d2 == d
if xb.is_contiguous():
xb_row_major = True
elif xb.t().is_contiguous():
xb = xb.t()
xb_row_major = False
else:
raise TypeError("matrix should be row or column-major")
xb_ptr = swig_ptr_from_FloatTensor(xb)
if D is None:
D = torch.empty(nq, k, device=xb.device, dtype=torch.float32)
else:
assert D.shape == (nq, k)
assert D.device == xb.device
if I is None:
I = torch.empty(nq, k, device=xb.device, dtype=torch.int64)
else:
assert I.shape == (nq, k)
assert I.device == xb.device
D_ptr = swig_ptr_from_FloatTensor(D)
I_ptr = swig_ptr_from_LongTensor(I)
faiss.bruteForceKnn(
res,
metric,
xb_ptr,
xb_row_major,
nb,
xq_ptr,
xq_row_major,
nq,
d,
k,
D_ptr,
I_ptr,
)
return D, I
def index_init_gpu(ngpus, feat_dim):
flat_config = []
for i in range(ngpus):
cfg = faiss.GpuIndexFlatConfig()
cfg.useFloat16 = False
cfg.device = i
flat_config.append(cfg)
res = [faiss.StandardGpuResources() for i in range(ngpus)]
indexes = [
faiss.GpuIndexFlatL2(res[i], feat_dim, flat_config[i]) for i in range(ngpus)
]
index = faiss.IndexShards(feat_dim)
for sub_index in indexes:
index.add_shard(sub_index)
index.reset()
return index
def index_init_cpu(feat_dim):
return faiss.IndexFlatL2(feat_dim)
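# A minimal usage sketch of the GPU brute-force search, assuming a
# CUDA-enabled faiss build and a hypothetical contiguous float32 CUDA
# tensor `feats`:
def _demo_knn(feats, k=5):
    res = faiss.StandardGpuResources()
    D, I = search_raw_array_pytorch(res, feats, feats, k)
    return D, I  # (n, k) distances and neighbor indices, as torch tensors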
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import errno
import logging
import os
import shutil
from collections import OrderedDict
from typing import (
IO,
Any,
Callable,
Dict,
List,
MutableMapping,
Optional,
Union,
)
__all__ = ["PathManager", "get_cache_dir"]
def get_cache_dir(cache_dir: Optional[str] = None) -> str:
"""
Returns a default directory to cache static files
(usually downloaded from Internet), if None is provided.
Args:
cache_dir (None or str): if not None, will be returned as is.
If None, returns the default cache directory as:
1) $FVCORE_CACHE, if set
2) otherwise ~/.torch/fvcore_cache
"""
if cache_dir is None:
cache_dir = os.path.expanduser(
os.getenv("FVCORE_CACHE", "~/.torch/fvcore_cache")
)
return cache_dir
class PathHandler:
"""
PathHandler is a base class that defines common I/O functionality for a URI
protocol. It routes I/O for a generic URI which may look like "protocol://*"
or a canonical filepath "/foo/bar/baz".
"""
_strict_kwargs_check = True
def _check_kwargs(self, kwargs: Dict[str, Any]) -> None:
"""
Checks if the given arguments are empty. Throws a ValueError if strict
kwargs checking is enabled and args are non-empty. If strict kwargs
checking is disabled, only a warning is logged.
Args:
kwargs (Dict[str, Any])
"""
if self._strict_kwargs_check:
if len(kwargs) > 0:
raise ValueError("Unused arguments: {}".format(kwargs))
else:
logger = logging.getLogger(__name__)
for k, v in kwargs.items():
logger.warning(
"[PathManager] {}={} argument ignored".format(k, v)
)
def _get_supported_prefixes(self) -> List[str]:
"""
Returns:
List[str]: the list of URI prefixes this PathHandler can support
"""
raise NotImplementedError()
def _get_local_path(self, path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk. In this case, this function is meant to be
used with read-only resources.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
raise NotImplementedError()
def _open(
self, path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
raise NotImplementedError()
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
raise NotImplementedError()
def _exists(self, path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
raise NotImplementedError()
def _isfile(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
raise NotImplementedError()
def _isdir(self, path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
raise NotImplementedError()
def _ls(self, path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
raise NotImplementedError()
def _mkdirs(self, path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
def _rm(self, path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
raise NotImplementedError()
class NativePathHandler(PathHandler):
"""
Handles paths that can be accessed using Python native system calls. This
handler uses `open()` and `os.*` calls on the given path.
"""
def _get_local_path(self, path: str, **kwargs: Any) -> str:
self._check_kwargs(kwargs)
return path
def _open(
self,
path: str,
mode: str = "r",
buffering: int = -1,
encoding: Optional[str] = None,
errors: Optional[str] = None,
newline: Optional[str] = None,
closefd: bool = True,
opener: Optional[Callable] = None,
**kwargs: Any,
) -> Union[IO[str], IO[bytes]]:
"""
Open a path.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy works as follows:
* Binary files are buffered in fixed-size chunks; the size of
the buffer is chosen using a heuristic trying to determine the
underlying device’s “block size” and falling back on
io.DEFAULT_BUFFER_SIZE. On many systems, the buffer will
typically be 4096 or 8192 bytes long.
encoding (Optional[str]): the name of the encoding used to decode or
encode the file. This should only be used in text mode.
errors (Optional[str]): an optional string that specifies how encoding
and decoding errors are to be handled. This cannot be used in binary
mode.
newline (Optional[str]): controls how universal newlines mode works
(it only applies to text mode). It can be None, '', '\n', '\r',
and '\r\n'.
closefd (bool): If closefd is False and a file descriptor rather than
a filename was given, the underlying file descriptor will be kept
open when the file is closed. If a filename is given closefd must
be True (the default) otherwise an error will be raised.
opener (Optional[Callable]): A custom opener can be used by passing
a callable as opener. The underlying file descriptor for the file
object is then obtained by calling opener with (file, flags).
opener must return an open file descriptor (passing os.open as opener
results in functionality similar to passing None).
See https://docs.python.org/3/library/functions.html#open for details.
Returns:
file: a file-like object.
"""
self._check_kwargs(kwargs)
return open( # type: ignore
path,
mode,
buffering=buffering,
encoding=encoding,
errors=errors,
newline=newline,
closefd=closefd,
opener=opener,
)
def _copy(
self,
src_path: str,
dst_path: str,
overwrite: bool = False,
**kwargs: Any,
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
self._check_kwargs(kwargs)
if os.path.exists(dst_path) and not overwrite:
logger = logging.getLogger(__name__)
logger.error("Destination file {} already exists.".format(dst_path))
return False
try:
shutil.copyfile(src_path, dst_path)
return True
except Exception as e:
logger = logging.getLogger(__name__)
logger.error("Error in file copy - {}".format(str(e)))
return False
def _exists(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.exists(path)
def _isfile(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isfile(path)
def _isdir(self, path: str, **kwargs: Any) -> bool:
self._check_kwargs(kwargs)
return os.path.isdir(path)
def _ls(self, path: str, **kwargs: Any) -> List[str]:
self._check_kwargs(kwargs)
return os.listdir(path)
def _mkdirs(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
try:
os.makedirs(path, exist_ok=True)
except OSError as e:
            # EEXIST can still happen if multiple processes are creating the dir concurrently
if e.errno != errno.EEXIST:
raise
def _rm(self, path: str, **kwargs: Any) -> None:
self._check_kwargs(kwargs)
os.remove(path)
class PathManager:
"""
A class for users to open generic paths or translate generic paths to file names.
"""
_PATH_HANDLERS: MutableMapping[str, PathHandler] = OrderedDict()
_NATIVE_PATH_HANDLER = NativePathHandler()
@staticmethod
def __get_path_handler(path: str) -> PathHandler:
"""
Finds a PathHandler that supports the given path. Falls back to the native
PathHandler if no other handler is found.
Args:
path (str): URI path to resource
Returns:
handler (PathHandler)
"""
for p in PathManager._PATH_HANDLERS.keys():
if path.startswith(p):
return PathManager._PATH_HANDLERS[p]
return PathManager._NATIVE_PATH_HANDLER
@staticmethod
def open(
path: str, mode: str = "r", buffering: int = -1, **kwargs: Any
) -> Union[IO[str], IO[bytes]]:
"""
Open a stream to a URI, similar to the built-in `open`.
Args:
path (str): A URI supported by this PathHandler
mode (str): Specifies the mode in which the file is opened. It defaults
to 'r'.
buffering (int): An optional integer used to set the buffering policy.
Pass 0 to switch buffering off and an integer >= 1 to indicate the
size in bytes of a fixed-size chunk buffer. When no buffering
argument is given, the default buffering policy depends on the
underlying I/O implementation.
Returns:
file: a file-like object.
"""
return PathManager.__get_path_handler(path)._open( # type: ignore
path, mode, buffering=buffering, **kwargs
)
@staticmethod
def copy(
src_path: str, dst_path: str, overwrite: bool = False, **kwargs: Any
) -> bool:
"""
Copies a source path to a destination path.
Args:
src_path (str): A URI supported by this PathHandler
dst_path (str): A URI supported by this PathHandler
overwrite (bool): Bool flag for forcing overwrite of existing file
Returns:
status (bool): True on success
"""
# Copying across handlers is not supported.
assert PathManager.__get_path_handler( # type: ignore
src_path
) == PathManager.__get_path_handler(dst_path)
return PathManager.__get_path_handler(src_path)._copy(
src_path, dst_path, overwrite, **kwargs
)
@staticmethod
def get_local_path(path: str, **kwargs: Any) -> str:
"""
Get a filepath which is compatible with native Python I/O such as `open`
and `os.path`.
If URI points to a remote resource, this function may download and cache
the resource to local disk.
Args:
path (str): A URI supported by this PathHandler
Returns:
local_path (str): a file path which exists on the local file system
"""
return PathManager.__get_path_handler( # type: ignore
path
)._get_local_path(path, **kwargs)
@staticmethod
def exists(path: str, **kwargs: Any) -> bool:
"""
Checks if there is a resource at the given URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path exists
"""
return PathManager.__get_path_handler(path)._exists( # type: ignore
path, **kwargs
)
@staticmethod
def isfile(path: str, **kwargs: Any) -> bool:
"""
        Checks if the resource at the given URI is a file.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a file
"""
return PathManager.__get_path_handler(path)._isfile( # type: ignore
path, **kwargs
)
@staticmethod
def isdir(path: str, **kwargs: Any) -> bool:
"""
Checks if the resource at the given URI is a directory.
Args:
path (str): A URI supported by this PathHandler
Returns:
bool: true if the path is a directory
"""
return PathManager.__get_path_handler(path)._isdir( # type: ignore
path, **kwargs
)
@staticmethod
def ls(path: str, **kwargs: Any) -> List[str]:
"""
List the contents of the directory at the provided URI.
Args:
path (str): A URI supported by this PathHandler
Returns:
List[str]: list of contents in given path
"""
return PathManager.__get_path_handler(path)._ls( # type: ignore
path, **kwargs
)
@staticmethod
def mkdirs(path: str, **kwargs: Any) -> None:
"""
Recursive directory creation function. Like mkdir(), but makes all
intermediate-level directories needed to contain the leaf directory.
Similar to the native `os.makedirs`.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._mkdirs( # type: ignore
path, **kwargs
)
@staticmethod
def rm(path: str, **kwargs: Any) -> None:
"""
Remove the file (not directory) at the provided URI.
Args:
path (str): A URI supported by this PathHandler
"""
return PathManager.__get_path_handler(path)._rm( # type: ignore
path, **kwargs
)
@staticmethod
def register_handler(handler: PathHandler) -> None:
"""
Register a path handler associated with `handler._get_supported_prefixes`
URI prefixes.
Args:
handler (PathHandler)
"""
assert isinstance(handler, PathHandler), handler
for prefix in handler._get_supported_prefixes():
assert prefix not in PathManager._PATH_HANDLERS
PathManager._PATH_HANDLERS[prefix] = handler
# Sort path handlers in reverse order so longer prefixes take priority,
        # e.g. http://foo/bar before http://foo
PathManager._PATH_HANDLERS = OrderedDict(
sorted(
PathManager._PATH_HANDLERS.items(),
key=lambda t: t[0],
reverse=True,
)
)
@staticmethod
def set_strict_kwargs_checking(enable: bool) -> None:
"""
Toggles strict kwargs checking. If enabled, a ValueError is thrown if any
unused parameters are passed to a PathHandler function. If disabled, only
a warning is given.
With a centralized file API, there's a tradeoff of convenience and
correctness delegating arguments to the proper I/O layers. An underlying
`PathHandler` may support custom arguments which should not be statically
exposed on the `PathManager` function. For example, a custom `HTTPURLHandler`
may want to expose a `cache_timeout` argument for `open()` which specifies
how old a locally cached resource can be before it's refetched from the
remote server. This argument would not make sense for a `NativePathHandler`.
If strict kwargs checking is disabled, `cache_timeout` can be passed to
`PathManager.open` which will forward the arguments to the underlying
handler. By default, checking is enabled since it is innately unsafe:
multiple `PathHandler`s could reuse arguments with different semantic
meanings or types.
Args:
enable (bool)
"""
PathManager._NATIVE_PATH_HANDLER._strict_kwargs_check = enable
for handler in PathManager._PATH_HANDLERS.values():
handler._strict_kwargs_check = enable
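# A minimal usage sketch: plain filesystem paths fall through to the
# NativePathHandler, so PathManager can stand in for open()/os.makedirs
# (the directory below is hypothetical):
def _demo_path_manager(tmp_dir="/tmp/fastreid_demo"):
    PathManager.mkdirs(tmp_dir)
    path = os.path.join(tmp_dir, "example.txt")
    with PathManager.open(path, "w") as f:
        f.write("hello")
    return PathManager.exists(path) and PathManager.isfile(path)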
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
import numpy as np
from typing import List, Optional, Tuple
class HistoryBuffer:
"""
Track a series of scalar values and provide access to smoothed values over a
window or the global average of the series.
"""
def __init__(self, max_length: int = 1000000):
"""
Args:
max_length: maximal number of values that can be stored in the
buffer. When the capacity of the buffer is exhausted, old
values will be removed.
"""
self._max_length: int = max_length
self._data: List[Tuple[float, float]] = [] # (value, iteration) pairs
self._count: int = 0
self._global_avg: float = 0
    def update(self, value: float, iteration: Optional[float] = None):
"""
Add a new scalar value produced at certain iteration. If the length
of the buffer exceeds self._max_length, the oldest element will be
removed from the buffer.
"""
if iteration is None:
iteration = self._count
if len(self._data) == self._max_length:
self._data.pop(0)
self._data.append((value, iteration))
self._count += 1
self._global_avg += (value - self._global_avg) / self._count
def latest(self):
"""
Return the latest scalar value added to the buffer.
"""
return self._data[-1][0]
def median(self, window_size: int):
"""
Return the median of the latest `window_size` values in the buffer.
"""
return np.median([x[0] for x in self._data[-window_size:]])
def avg(self, window_size: int):
"""
Return the mean of the latest `window_size` values in the buffer.
"""
return np.mean([x[0] for x in self._data[-window_size:]])
def global_avg(self):
"""
Return the mean of all the elements in the buffer. Note that this
includes those getting removed due to limited buffer storage.
"""
return self._global_avg
def values(self):
"""
Returns:
list[(number, iteration)]: content of the current buffer.
"""
return self._data
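# A small worked sketch of the smoothing behavior: the window median damps an
# outlier while global_avg reflects every value ever pushed:
def _demo_history_buffer():
    buf = HistoryBuffer(max_length=100)
    for i, v in enumerate([1.0, 100.0, 1.0, 1.0]):
        buf.update(v, iteration=i)
    return buf.median(window_size=3), buf.global_avg()  # -> (1.0, 25.75)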
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import functools
import logging
import os
import sys
import time
from collections import Counter
from termcolor import colored
from .file_io import PathManager
class _ColorfulFormatter(logging.Formatter):
def __init__(self, *args, **kwargs):
self._root_name = kwargs.pop("root_name") + "."
self._abbrev_name = kwargs.pop("abbrev_name", "")
if len(self._abbrev_name):
self._abbrev_name = self._abbrev_name + "."
super(_ColorfulFormatter, self).__init__(*args, **kwargs)
def formatMessage(self, record):
record.name = record.name.replace(self._root_name, self._abbrev_name)
log = super(_ColorfulFormatter, self).formatMessage(record)
if record.levelno == logging.WARNING:
prefix = colored("WARNING", "red", attrs=["blink"])
elif record.levelno == logging.ERROR or record.levelno == logging.CRITICAL:
prefix = colored("ERROR", "red", attrs=["blink", "underline"])
else:
return log
return prefix + " " + log
@functools.lru_cache() # so that calling setup_logger multiple times won't add many handlers
def setup_logger(
output=None, distributed_rank=0, *, color=True, name="fastreid", abbrev_name=None
):
"""
Args:
output (str): a file name or a directory to save log. If None, will not save log file.
If ends with ".txt" or ".log", assumed to be a file name.
Otherwise, logs will be saved to `output/log.txt`.
name (str): the root module name of this logger
abbrev_name (str): an abbreviation of the module, to avoid long names in logs.
Set to "" to not log the root module in logs.
By default, will abbreviate "detectron2" to "d2" and leave other
modules unchanged.
"""
logger = logging.getLogger(name)
logger.setLevel(logging.DEBUG)
logger.propagate = False
if abbrev_name is None:
abbrev_name = "d2" if name == "detectron2" else name
plain_formatter = logging.Formatter(
"[%(asctime)s] %(name)s %(levelname)s: %(message)s", datefmt="%m/%d %H:%M:%S"
)
# stdout logging: master only
if distributed_rank == 0:
ch = logging.StreamHandler(stream=sys.stdout)
ch.setLevel(logging.DEBUG)
if color:
formatter = _ColorfulFormatter(
colored("[%(asctime)s %(name)s]: ", "green") + "%(message)s",
datefmt="%m/%d %H:%M:%S",
root_name=name,
abbrev_name=str(abbrev_name),
)
else:
formatter = plain_formatter
ch.setFormatter(formatter)
logger.addHandler(ch)
# file logging: all workers
if output is not None:
if output.endswith(".txt") or output.endswith(".log"):
filename = output
else:
filename = os.path.join(output, "log.txt")
if distributed_rank > 0:
filename = filename + ".rank{}".format(distributed_rank)
PathManager.mkdirs(os.path.dirname(filename))
fh = logging.StreamHandler(_cached_log_stream(filename))
fh.setLevel(logging.DEBUG)
fh.setFormatter(plain_formatter)
logger.addHandler(fh)
return logger
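# A minimal usage sketch, assuming logs should go to a hypothetical "logs/"
# directory on the rank-0 process:
def _demo_setup_logger():
    logger = setup_logger(output="logs", distributed_rank=0, name="fastreid")
    logger.info("logger initialized")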
# cache the opened file object, so that different calls to `setup_logger`
# with the same file name can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
return PathManager.open(filename, "a")
"""
Below are some other convenient logging methods.
They are mainly adopted from
https://github.com/abseil/abseil-py/blob/master/absl/logging/__init__.py
"""
def _find_caller():
"""
Returns:
str: module name of the caller
tuple: a hashable key to be used to identify different callers
"""
frame = sys._getframe(2)
while frame:
code = frame.f_code
if os.path.join("utils", "logger.") not in code.co_filename:
mod_name = frame.f_globals["__name__"]
if mod_name == "__main__":
mod_name = "detectron2"
return mod_name, (code.co_filename, frame.f_lineno, code.co_name)
frame = frame.f_back
_LOG_COUNTER = Counter()
_LOG_TIMER = {}
def log_first_n(lvl, msg, n=1, *, name=None, key="caller"):
"""
Log only for the first n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
key (str or tuple[str]): the string(s) can be one of "caller" or
"message", which defines how to identify duplicated logs.
For example, if called with `n=1, key="caller"`, this function
will only log the first call from the same caller, regardless of
the message content.
If called with `n=1, key="message"`, this function will log the
same content only once, even if they are called from different places.
            If called with `n=1, key=("caller", "message")`, this function
            will log only if the same caller has not logged the same message before.
"""
if isinstance(key, str):
key = (key,)
assert len(key) > 0
caller_module, caller_key = _find_caller()
hash_key = ()
if "caller" in key:
hash_key = hash_key + caller_key
if "message" in key:
hash_key = hash_key + (msg,)
_LOG_COUNTER[hash_key] += 1
if _LOG_COUNTER[hash_key] <= n:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n(lvl, msg, n=1, *, name=None):
"""
Log once per n times.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
_LOG_COUNTER[key] += 1
if n == 1 or _LOG_COUNTER[key] % n == 1:
logging.getLogger(name or caller_module).log(lvl, msg)
def log_every_n_seconds(lvl, msg, n=1, *, name=None):
"""
Log no more than once per n seconds.
Args:
lvl (int): the logging level
msg (str):
n (int):
name (str): name of the logger to use. Will use the caller's module by default.
"""
caller_module, key = _find_caller()
last_logged = _LOG_TIMER.get(key, None)
current_time = time.time()
if last_logged is None or current_time - last_logged >= n:
logging.getLogger(name or caller_module).log(lvl, msg)
_LOG_TIMER[key] = current_time
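# A minimal usage sketch of the throttled helpers inside a hypothetical hot loop:
def _demo_throttled_logging(num_steps=1000):
    for _ in range(num_steps):
        log_first_n(logging.WARNING, "slow dataloader detected", n=1)
        log_every_n(logging.INFO, "heartbeat", n=100)
        log_every_n_seconds(logging.INFO, "timed heartbeat", n=5)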
# def create_small_table(small_dict):
# """
# Create a small table using the keys of small_dict as headers. This is only
# suitable for small dictionaries.
# Args:
# small_dict (dict): a result dictionary of only a few items.
# Returns:
# str: the table as a string.
# """
# keys, values = tuple(zip(*small_dict.items()))
# table = tabulate(
# [values],
# headers=keys,
# tablefmt="pipe",
# floatfmt=".3f",
# stralign="center",
# numalign="center",
# )
# return table
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
# based on: https://github.com/PhilJd/contiguous_pytorch_params/blob/master/contiguous_params/params.py
from collections import OrderedDict
import torch
class ContiguousParams:
def __init__(self, parameters):
# Create a list of the parameters to prevent emptying an iterator.
self._parameters = parameters
self._param_buffer = []
self._grad_buffer = []
self._group_dict = OrderedDict()
self._name_buffer = []
self._init_buffers()
# Store the data pointers for each parameter into the buffer. These
# can be used to check if an operation overwrites the gradient/data
# tensor (invalidating the assumption of a contiguous buffer).
self.data_pointers = []
self.grad_pointers = []
self.make_params_contiguous()
def _init_buffers(self):
dtype = self._parameters[0]["params"][0].dtype
device = self._parameters[0]["params"][0].device
if not all(p["params"][0].dtype == dtype for p in self._parameters):
raise ValueError("All parameters must be of the same dtype.")
if not all(p["params"][0].device == device for p in self._parameters):
raise ValueError("All parameters must be on the same device.")
# Group parameters by lr and weight decay
for param_dict in self._parameters:
freeze_status = param_dict["freeze_status"]
param_key = freeze_status + '_' + str(param_dict["lr"]) + '_' + str(param_dict["weight_decay"])
if param_key not in self._group_dict:
self._group_dict[param_key] = []
self._group_dict[param_key].append(param_dict)
for key, params in self._group_dict.items():
size = sum(p["params"][0].numel() for p in params)
self._param_buffer.append(torch.zeros(size, dtype=dtype, device=device))
self._grad_buffer.append(torch.zeros(size, dtype=dtype, device=device))
self._name_buffer.append(key)
def make_params_contiguous(self):
"""Create a buffer to hold all params and update the params to be views of the buffer.
Args:
parameters: An iterable of parameters.
"""
for i, params in enumerate(self._group_dict.values()):
index = 0
for param_dict in params:
p = param_dict["params"][0]
size = p.numel()
self._param_buffer[i][index:index + size] = p.data.view(-1)
p.data = self._param_buffer[i][index:index + size].view(p.data.shape)
p.grad = self._grad_buffer[i][index:index + size].view(p.data.shape)
self.data_pointers.append(p.data.data_ptr)
self.grad_pointers.append(p.grad.data.data_ptr)
index += size
# Bend the param_buffer to use grad_buffer to track its gradients.
self._param_buffer[i].grad = self._grad_buffer[i]
def contiguous(self):
"""Return all parameters as one contiguous buffer."""
return [{
"freeze_status": self._name_buffer[i].split('_')[0],
"params": self._param_buffer[i],
"lr": float(self._name_buffer[i].split('_')[1]),
"weight_decay": float(self._name_buffer[i].split('_')[2]),
} for i in range(len(self._param_buffer))]
def original(self):
"""Return the non-flattened parameters."""
return self._parameters
def buffer_is_valid(self):
"""Verify that all parameters and gradients still use the buffer."""
i = 0
for params in self._group_dict.values():
for param_dict in params:
p = param_dict["params"][0]
data_ptr = self.data_pointers[i]
grad_ptr = self.grad_pointers[i]
if (p.data.data_ptr() != data_ptr()) or (p.grad.data.data_ptr() != grad_ptr()):
return False
i += 1
return True
def assert_buffer_is_valid(self):
if not self.buffer_is_valid():
raise ValueError(
"The data or gradient buffer has been invalidated. Please make "
"sure to use inplace operations only when updating parameters "
"or gradients.")
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import itertools
import torch
BN_MODULE_TYPES = (
torch.nn.BatchNorm1d,
torch.nn.BatchNorm2d,
torch.nn.BatchNorm3d,
torch.nn.SyncBatchNorm,
)
@torch.no_grad()
def update_bn_stats(model, data_loader, num_iters: int = 200):
"""
Recompute and update the batch norm stats to make them more precise. During
training both BN stats and the weight are changing after every iteration, so
the running average can not precisely reflect the actual stats of the
current model.
In this function, the BN stats are recomputed with fixed weights, to make
the running average more precise. Specifically, it computes the true average
of per-batch mean/variance instead of the running average.
Args:
model (nn.Module): the model whose bn stats will be recomputed.
Note that:
1. This function will not alter the training mode of the given model.
               Users are responsible for setting the layers that need
precise-BN to training mode, prior to calling this function.
2. Be careful if your models contain other stateful layers in
addition to BN, i.e. layers whose state can change in forward
iterations. This function will alter their state. If you wish
them unchanged, you need to either pass in a submodule without
those layers, or backup the states.
data_loader (iterator): an iterator. Produce data as inputs to the model.
num_iters (int): number of iterations to compute the stats.
"""
bn_layers = get_bn_modules(model)
if len(bn_layers) == 0:
return
# In order to make the running stats only reflect the current batch, the
# momentum is disabled.
# bn.running_mean = (1 - momentum) * bn.running_mean + momentum * batch_mean
# Setting the momentum to 1.0 to compute the stats without momentum.
momentum_actual = [bn.momentum for bn in bn_layers]
for bn in bn_layers:
bn.momentum = 1.0
# Note that running_var actually means "running average of variance"
running_mean = [torch.zeros_like(bn.running_mean) for bn in bn_layers]
running_var = [torch.zeros_like(bn.running_var) for bn in bn_layers]
for ind, inputs in enumerate(itertools.islice(data_loader, num_iters)):
inputs['targets'].fill_(-1)
with torch.no_grad(): # No need to backward
model(inputs)
for i, bn in enumerate(bn_layers):
# Accumulates the bn stats.
running_mean[i] += (bn.running_mean - running_mean[i]) / (ind + 1)
running_var[i] += (bn.running_var - running_var[i]) / (ind + 1)
# We compute the "average of variance" across iterations.
assert ind == num_iters - 1, (
"update_bn_stats is meant to run for {} iterations, "
"but the dataloader stops at {} iterations.".format(num_iters, ind)
)
for i, bn in enumerate(bn_layers):
# Sets the precise bn stats.
bn.running_mean = running_mean[i]
bn.running_var = running_var[i]
bn.momentum = momentum_actual[i]
def get_bn_modules(model):
"""
Find all BatchNorm (BN) modules that are in training mode. See
fvcore.precise_bn.BN_MODULE_TYPES for a list of all modules that are
included in this search.
Args:
model (nn.Module): a model possibly containing BN modules.
Returns:
list[nn.Module]: all BN modules in the model.
"""
# Finds all the bn layers.
bn_layers = [
m for m in model.modules() if m.training and isinstance(m, BN_MODULE_TYPES)
]
return bn_layers
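# A minimal usage sketch, assuming `data_loader` yields dicts containing a
# 'targets' tensor, as update_bn_stats expects:
def _demo_precise_bn(model, data_loader):
    model.train()  # BN layers are only found when they are in training mode
    update_bn_stats(model, data_loader, num_iters=200)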
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
from typing import Dict, Optional
class Registry(object):
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
To create a registry (e.g. a backbone registry):
.. code-block:: python
BACKBONE_REGISTRY = Registry('BACKBONE')
To register an object:
.. code-block:: python
@BACKBONE_REGISTRY.register()
class MyBackbone():
...
Or:
.. code-block:: python
BACKBONE_REGISTRY.register(MyBackbone)
"""
def __init__(self, name: str) -> None:
"""
Args:
name (str): the name of this registry
"""
self._name: str = name
self._obj_map: Dict[str, object] = {}
def _do_register(self, name: str, obj: object) -> None:
assert (
name not in self._obj_map
), "An object named '{}' was already registered in '{}' registry!".format(
name, self._name
)
self._obj_map[name] = obj
def register(self, obj: object = None) -> Optional[object]:
"""
        Register the given object under the name `obj.__name__`.
Can be used as either a decorator or not. See docstring of this class for usage.
"""
if obj is None:
# used as a decorator
def deco(func_or_class: object) -> object:
name = func_or_class.__name__ # pyre-ignore
self._do_register(name, func_or_class)
return func_or_class
return deco
# used as a function call
name = obj.__name__ # pyre-ignore
self._do_register(name, obj)
def get(self, name: str) -> object:
ret = self._obj_map.get(name)
if ret is None:
raise KeyError(
"No object named '{}' found in '{}' registry!".format(
name, self._name
)
)
return ret
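# A minimal usage sketch mirroring the class docstring, with a hypothetical
# backbone registry:
DEMO_BACKBONE_REGISTRY = Registry("DEMO_BACKBONE")

@DEMO_BACKBONE_REGISTRY.register()
class DemoBackbone:
    pass

def _demo_build_backbone(name="DemoBackbone"):
    return DEMO_BACKBONE_REGISTRY.get(name)()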