Commit 4f533dd8 authored by mibaumgartner

training

parent 5d61a79b
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import math
from typing import List, Union, Sequence
from torch.optim.lr_scheduler import _LRScheduler
from torch.optim.optimizer import Optimizer
from loguru import logger
def linear_warm_up(
iteration: int,
initial_lr: float,
num_iterations: int,
final_lr: float,
) -> float:
"""
Linear learning rate warm up
Args:
iteration: current iteration
        initial_lr: learning rate at the start of the warm up
        num_iterations: total number of iterations of the warm up
final_lr: final learning rate of warmup
Returns:
float: learning rate
"""
assert final_lr > initial_lr
if iteration >= num_iterations:
logger.warning(f"WarmUp was stepped too often, {iteration} "
f"but only {num_iterations} were expected!")
return initial_lr + (final_lr - initial_lr) * (float(iteration) / float(num_iterations))
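# Minimal usage sketch (hypothetical values): halfway through a 50-iteration
# warm up from 1e-6 to 1e-2 the learning rate lies halfway between the two.
def _example_linear_warm_up() -> float:
    # 1e-6 + (1e-2 - 1e-6) * (25 / 50) ~= 5.0e-3
    return linear_warm_up(iteration=25, initial_lr=1e-6, num_iterations=50, final_lr=1e-2)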
def poly_lr(
iteration: int,
initial_lr: float,
num_iterations: int,
gamma: float,
) -> float:
"""
    initial_lr * (1 - iteration / num_iterations) ** gamma
Adapted from
https://github.com/MIC-DKFZ/nnUNet/blob/master/nnunet/training/learning_rate/poly_lr.py
https://arxiv.org/abs/1904.08128
Args:
iteration: current iteration
initial_lr: initial learning rate for poly lr
num_iterations: total number of iterations of poly lr
gamma: gamma value
Returns:
float: learning rate
"""
if iteration >= num_iterations:
logger.warning(f"PolyLR was stepped too often, {iteration} "
f"but only {num_iterations} were expected! "
f"Using {num_iterations - 1} for lr computation.")
iteration = num_iterations - 1
return initial_lr * (1 - iteration / float(num_iterations)) ** gamma
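# Minimal usage sketch (hypothetical values, gamma=0.9 as in the nnU-Net
# schedule referenced above): the lr decays from 1e-2 towards 0 over 1000 iterations.
def _example_poly_lr() -> float:
    # 1e-2 * (1 - 500 / 1000) ** 0.9 ~= 5.4e-3
    return poly_lr(iteration=500, initial_lr=1e-2, num_iterations=1000, gamma=0.9)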
def cyclic_linear_lr(
iteration: int,
num_iterations_cycle: int,
initial_lr: float,
final_lr: float,
) -> float:
"""
Linearly cycle learning rate
Args:
iteration: current iteration
num_iterations_cycle: number of iterations per cycle
initial_lr: learning rate to start cycle
final_lr: learning rate to end cycle
Returns:
float: learning rate
"""
cycle_iteration = int(iteration) % num_iterations_cycle
    # fraction of the cycle that is still remaining: 1 at the cycle start, 0 at the end
    lr_multiplier = 1 - (cycle_iteration / float(num_iterations_cycle))
    # decay linearly from initial_lr (cycle start) to final_lr (cycle end)
    return final_lr + (initial_lr - final_lr) * lr_multiplier
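# Minimal usage sketch (hypothetical values): within each 100-iteration cycle
# the learning rate decays linearly from initial_lr to final_lr, then resets.
def _example_cyclic_linear_lr() -> float:
    # iteration 125 is iteration 25 of the second cycle:
    # 1e-6 + (1e-2 - 1e-6) * (1 - 25 / 100) ~= 7.5e-3
    return cyclic_linear_lr(iteration=125, num_iterations_cycle=100, initial_lr=1e-2, final_lr=1e-6)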
def cosine_annealing_lr(
iteration: int,
num_iterations: int,
initial_lr: float,
final_lr: float,
) -> float:
"""
Cosine annealing NO restarts
Args:
iteration: current iteration
        num_iterations: total number of iterations of cosine lr
initial_lr: learning rate to start
final_lr: learning rate to end
Returns:
float: learning rate
"""
    return final_lr + 0.5 * (initial_lr - final_lr) * (
        1 + math.cos(math.pi * float(iteration) / float(num_iterations)))
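# Minimal usage sketch (hypothetical values): cosine decay from 1e-2 to 1e-6
# without restarts; halfway through, the lr equals the mean of both endpoints.
def _example_cosine_annealing_lr() -> float:
    # 1e-6 + 0.5 * (1e-2 - 1e-6) * (1 + cos(pi * 0.5)) ~= 5.0e-3
    return cosine_annealing_lr(iteration=500, num_iterations=1000, initial_lr=1e-2, final_lr=1e-6)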
class LinearWarmupPolyLR(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
warm_iterations: int,
warm_lr: Union[float, Sequence[float]],
poly_gamma: float,
num_iterations: int,
last_epoch: int = -1,
) -> None:
"""
        Linear Warm Up LR -> Poly LR
Args:
optimizer: optimizer for lr scheduling
warm_iterations: number of warmup iterations
warm_lr: initial learning rate of warm up
poly_gamma: gamma of poly lr
num_iterations: total number of iterations (including warmup)
last_epoch: The index of the last epoch. Defaults to -1.
"""
self.num_iterations = num_iterations
# warmup
self.warm_iterations = warm_iterations
        if not isinstance(warm_lr, (list, tuple)):
            self.warm_lr = [warm_lr] * len(optimizer.param_groups)
        else:
            if len(warm_lr) != len(optimizer.param_groups):
                raise ValueError("Expected {} warm_lr, but got {}".format(
                    len(optimizer.param_groups), len(warm_lr)))
            self.warm_lr = list(warm_lr)
# poly lr
self.poly_iterations = self.num_iterations - self.warm_iterations
self.poly_gamma = poly_gamma
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
if self.last_epoch < self.warm_iterations:
# warm up period
lrs = [linear_warm_up(
iteration=self._step_count,
initial_lr=self.warm_lr[idx],
num_iterations=self.warm_iterations,
final_lr=base_lr,
) for idx, base_lr in enumerate(self.base_lrs)]
else:
# poly lr phase
lrs = [poly_lr(
iteration=self._step_count - self.warm_iterations,
initial_lr=base_lr,
num_iterations=self.poly_iterations,
gamma=self.poly_gamma,
) for idx, base_lr in enumerate(self.base_lrs)]
return lrs
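# Minimal usage sketch (hypothetical model and hyperparameters): warm up from
# 1e-6 to the optimizer lr of 1e-2 over 1000 iterations, then decay via poly lr.
def _example_linear_warmup_poly_lr():
    import torch
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.9)
    scheduler = LinearWarmupPolyLR(
        optimizer=optimizer,
        warm_iterations=1000,
        warm_lr=1e-6,
        poly_gamma=0.9,
        num_iterations=100_000,
    )
    return optimizer, scheduler  # call scheduler.step() once per training iteration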
class CycleLinear(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
cycle_num_iterations: int,
cycle_initial_lr: Union[float, Sequence[float]],
                 cycle_final_lr: Union[float, Sequence[float]],
last_epoch: int = -1,
) -> None:
"""
Cyclic learning rates with linear decay
Args:
optimizer: optimizer for lr scheduling
cycle_num_iterations: number of iterations per cycle
cycle_initial_lr: initial learning rate of cycle
cycle_final_lr: final learning rate of cycle
last_epoch: The index of the last epoch. Defaults to -1.
"""
# cycle linear lr
self.cycle_num_iterations = cycle_num_iterations
        if not isinstance(cycle_initial_lr, (list, tuple)):
            self.cycle_initial_lr = [cycle_initial_lr] * len(optimizer.param_groups)
        else:
            if len(cycle_initial_lr) != len(optimizer.param_groups):
                raise ValueError("Expected {} cycle_initial_lr, but got {}".format(
                    len(optimizer.param_groups), len(cycle_initial_lr)))
            self.cycle_initial_lr = list(cycle_initial_lr)
        if not isinstance(cycle_final_lr, (list, tuple)):
            self.cycle_final_lr = [cycle_final_lr] * len(optimizer.param_groups)
        else:
            if len(cycle_final_lr) != len(optimizer.param_groups):
                raise ValueError("Expected {} cycle_final_lr, but got {}".format(
                    len(optimizer.param_groups), len(cycle_final_lr)))
            self.cycle_final_lr = list(cycle_final_lr)
super().__init__(optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
lrs = [cyclic_linear_lr(
            iteration=max(self._step_count - 1, 0),  # _step_count already includes the initialization step
num_iterations_cycle=self.cycle_num_iterations,
initial_lr=self.cycle_initial_lr[idx],
final_lr=self.cycle_final_lr[idx],
) for idx, base_lr in enumerate(self.base_lrs)]
return lrs
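# Minimal usage sketch (hypothetical model and hyperparameters): decay the lr
# from 1e-3 to 1e-6 within every cycle of 250 iterations, then restart.
def _example_cycle_linear():
    import torch
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)
    scheduler = CycleLinear(
        optimizer=optimizer,
        cycle_num_iterations=250,
        cycle_initial_lr=1e-3,
        cycle_final_lr=1e-6,
    )
    return optimizer, scheduler  # call scheduler.step() once per training iteration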
class WarmUpExponential(_LRScheduler):
def __init__(self,
optimizer: Optimizer,
beta2: float,
last_epoch: int = -1,
):
"""
        Exponential learning rate warmup
        warmup_lr = base_lr * (1 - exp(-(1 - beta2) * t))
        for 2 / (1 - beta2) iterations
`On the adequacy of untuned warmup for adaptive optimization`
https://arxiv.org/abs/1910.04209
Args:
optimizer: optimizer to schedule lr from (best used with Adam,
AdamW)
beta2: second beta param of Adam optimizer.
last_epoch: The index of the last epoch. Defaults to -1.
"""
self.iterations = int(2. * (1. / (1. - beta2)))
self.beta2 = beta2
logger.info(f"Running exponential warmup for {self.iterations} iterations")
self.finished = False
super().__init__(optimizer=optimizer, last_epoch=last_epoch)
def get_lr(self) -> List[float]:
"""
Compute current learning rate for each param group
"""
# last epoch is automatically handled by parent class
        return [base_lr * (1 - math.exp(-(1 - self.beta2) * self.last_epoch))
                for base_lr in self.base_lrs]
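# Minimal usage sketch (hypothetical model): with Adam's default beta2=0.999
# the exponential warm up runs for 2 / (1 - 0.999) = 2000 iterations.
def _example_warmup_exponential():
    import torch
    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, betas=(0.9, 0.999))
    scheduler = WarmUpExponential(optimizer=optimizer, beta2=0.999)
    return optimizer, scheduler  # call scheduler.step() once per training iteration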
from nndet.training.optimizer.utils import (
get_params_no_wd_on_norm, identify_parameters, change_output_layer,
freeze_layers, unfreeze_layers,
)
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from typing import Dict, Sequence
import torch
import torch.nn as nn
import nndet.models.layers.norm as an
NORM_TYPES = [nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d,
nn.InstanceNorm1d, nn.InstanceNorm2d, nn.InstanceNorm3d,
nn.LayerNorm, nn.GroupNorm, nn.SyncBatchNorm, nn.LocalResponseNorm,
an.GroupNorm,
]
def get_params_no_wd_on_norm(model: torch.nn.Module, weight_decay: float):
"""
Apply weight decay to model but skip normalization layers
Args:
model (torch.nn.Module) : module for parameters
weight_decay (float) : weight decay for other parameters
Returns:
dict: dict with params and weight decay
See Also:
https://discuss.pytorch.org/t/weight-decay-in-the-optimizers-is-a-bad-idea-especially-with-batchnorm/16994/2
"""
identify_parameters(model, {"no_wd": NORM_TYPES})
return [
{'params': filter(lambda p: not hasattr(p, "no_wd"), model.parameters()), 'weight_decay': weight_decay},
{'params': filter(lambda p: hasattr(p, "no_wd"), model.parameters()), 'weight_decay': 0.},
]
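# Minimal usage sketch (hypothetical model): build an optimizer with two param
# groups, applying weight decay everywhere except to normalization layers.
def _example_get_params_no_wd_on_norm():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.ReLU())
    param_groups = get_params_no_wd_on_norm(model, weight_decay=1e-4)
    return torch.optim.SGD(param_groups, lr=1e-2, momentum=0.9)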
def identify_parameters(model: torch.nn.Module,
type_mapping: Dict[str, Sequence],
check_param_exist: bool = True):
"""
Add attribute to searched module types (can be used to filter for specific modules in parameter list)
Args:
model: module to add attributes to
type_mapping: items specify types of modules to search, key specifies name of attribute
check_param_exist: check if module already has attribute. Can be used to assure that
attributes are not overwritten, but can lead to wrong results for shared parameters and
non "primitive" types
"""
for module in model.modules():
for _name, _types in type_mapping.items():
if any([isinstance(module, _type) for _type in _types]):
for param in module.parameters():
if check_param_exist:
assert not hasattr(param, _name)
setattr(param, _name, True)
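# Minimal usage sketch (hypothetical model): tag all normalization parameters
# with a "no_wd" attribute that can later be used to build optimizer param groups.
def _example_identify_parameters():
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.GroupNorm(4, 8))
    identify_parameters(model, {"no_wd": NORM_TYPES})
    return [name for name, p in model.named_parameters() if hasattr(p, "no_wd")]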
def change_output_layer(model: torch.nn.Module, layer_name: str = "fc",
output_channels: int = 2, layer_type=torch.nn.Linear,
**kwargs) -> None:
"""
Change layer of module
Args:
model (torch.nn.Module): module where layer should be exchanged
layer_name (str): name of layer to exchange
output_channels (int): number of new output channels
layer_type (class): class of new layer
**kwargs: keyword arguments passed to constructor of new layer
"""
if not hasattr(model, layer_name):
raise ValueError(f"Model does not have layer {layer_name}.")
old_layer = getattr(model, layer_name)
input_channels = old_layer.in_features
setattr(model, layer_name,
layer_type(input_channels, output_channels, **kwargs))
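# Minimal usage sketch (hypothetical model): swap the final linear layer of a
# classifier so that it predicts 5 classes instead of 2.
def _example_change_output_layer():
    class TinyClassifier(nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = nn.Linear(16, 2)

    model = TinyClassifier()
    change_output_layer(model, layer_name="fc", output_channels=5)
    return model.fc  # Linear(in_features=16, out_features=5, bias=True)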
def freeze_layers(model: torch.nn.Module) -> None:
"""
    Freeze all parameters of a module
    When building the optimizer afterwards, use something like
    ``Optim([p for p in model.parameters() if p.requires_grad])``
    to make sure frozen parameters are excluded.
Args:
model(torch.nn.Module): module to freeze
"""
for param in model.parameters():
param.requires_grad = False
def unfreeze_layers(model: torch.nn.Module) -> None:
"""
    Unfreeze all parameters of a module
    When building the optimizer afterwards, use something like
    ``Optim([p for p in model.parameters() if p.requires_grad])``
    to make sure only the intended parameters are optimized.
    Args:
        model(torch.nn.Module): module to unfreeze
"""
for param in model.parameters():
param.requires_grad = True
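# Minimal usage sketch (hypothetical model): freeze a pretrained backbone and
# only optimize the parameters that still require gradients.
def _example_freeze_layers():
    backbone = nn.Conv2d(3, 8, 3)
    head = nn.Conv2d(8, 2, 1)
    freeze_layers(backbone)
    trainable = [p for m in (backbone, head) for p in m.parameters() if p.requires_grad]
    return torch.optim.SGD(trainable, lr=1e-2)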
"""
Copyright 2020 Division of Medical Image Computing, German Cancer Research Center (DKFZ), Heidelberg, Germany
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from abc import abstractmethod
from typing import Optional, Union, Callable
from loguru import logger
import torch
from torch.optim.lr_scheduler import _LRScheduler
from pytorch_lightning.callbacks import StochasticWeightAveraging
from pytorch_lightning.trainer.optimizers import _get_default_scheduler_config
from pytorch_lightning.utilities import rank_zero_warn
from nndet.training.learning_rate import CycleLinear
_AVG_FN = Callable[[torch.Tensor, torch.Tensor, torch.LongTensor], torch.FloatTensor]
class BaseSWA(StochasticWeightAveraging):
def __init__(
self,
swa_epoch_start: int,
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
update_statistics: Optional[bool] = False,
):
"""
New Base Class for Stochastic Weighted Averaging
Args:
swa_epoch_start: Epoch to start SWA weight saving.
avg_fn: Function to average saved weights. Defaults to None.
device: Device to save averaged model. Defaults to
torch.device("cpu").
            update_statistics: Perform a final update of the normalization
                layers. Defaults to False.
Notes: Does not support updating of norm weights after training
"""
super().__init__(
swa_epoch_start=swa_epoch_start,
swa_lrs=None,
annealing_epochs=10,
annealing_strategy="cos",
avg_fn=avg_fn,
device=device,
)
self.update_statistics = update_statistics
logger.info(f"Initialize SWA with swa epoch start {self.swa_start}")
def pl_module_contains_batch_norm(self, pl_module: 'pl.LightningModule'):
if self.update_statistics:
raise NotImplementedError("Updating the statistis of the "
"normalization layer is not suported yet.")
else:
return self.update_statistics
def on_train_epoch_start(self,
trainer: 'pl.Trainer',
pl_module: 'pl.LightningModule',
):
"""
        Replace the current lr scheduler with the SWA scheduler
"""
if trainer.current_epoch == self.swa_start:
optimizer = trainer.optimizers[0]
            # move average model to the requested device
self._average_model = self._average_model.to(self._device or pl_module.device)
_scheduler = self.get_swa_scheduler(optimizer)
self._swa_scheduler = _get_default_scheduler_config()
if not isinstance(_scheduler, dict):
_scheduler = {"scheduler": _scheduler}
self._swa_scheduler.update(_scheduler)
if trainer.lr_schedulers:
lr_scheduler = trainer.lr_schedulers[0]["scheduler"]
rank_zero_warn(f"Swapping lr_scheduler {lr_scheduler} for {self._swa_scheduler}")
trainer.lr_schedulers[0] = self._swa_scheduler
else:
trainer.lr_schedulers.append(self._swa_scheduler)
self.n_averaged = torch.tensor(0, dtype=torch.long, device=pl_module.device)
if self.swa_start <= trainer.current_epoch <= self.swa_end:
self.update_parameters(self._average_model, pl_module, self.n_averaged, self.avg_fn)
if trainer.current_epoch == self.swa_end + 1:
raise NotImplementedError("This should never happen (yet)")
@abstractmethod
def get_swa_scheduler(self, optimizer) -> Union[_LRScheduler, dict]:
"""
Generate LR scheduler for SWA
Args:
optimizer: optimizer to wrap
Returns:
Union[_LRScheduler, dict]: If a lr scheduler is returned it will
be stepped once per epoch. Can also return a whole config of
the scheduler to customize steps.
"""
raise NotImplementedError
class SWACycleLinear(BaseSWA):
def __init__(self,
swa_epoch_start: int,
cycle_initial_lr: float,
cycle_final_lr: float,
num_iterations_per_epoch: int,
avg_fn: Optional[_AVG_FN] = None,
device: Optional[Union[torch.device, str]] = torch.device("cpu"),
update_statistics: Optional[bool] = None,
):
"""
SWA based on :class:`CycleLinear`
Args:
swa_epoch_start: Epoch to start SWA weight saving.
cycle_initial_lr: initial learning rate of cycle
cycle_final_lr: final learning rate of cycle
num_iterations_per_epoch: number of train iterations per epoch
avg_fn: Function to average saved weights. Defaults to None.
device: Device to save averaged model. Defaults to
torch.device("cpu").
update_statistics: Perform a final update of the normalization
layers. Defaults to None.
"""
super().__init__(
swa_epoch_start=swa_epoch_start,
avg_fn=avg_fn,
device=device,
update_statistics=update_statistics,
)
self.cycle_initial_lr = cycle_initial_lr
self.cycle_final_lr = cycle_final_lr
self.num_iterations_per_epoch = num_iterations_per_epoch
def get_swa_scheduler(self, optimizer) -> Union[_LRScheduler, dict]:
return {
"scheduler": CycleLinear(
optimizer=optimizer,
cycle_num_iterations=self.num_iterations_per_epoch,
cycle_initial_lr=self.cycle_initial_lr,
cycle_final_lr=self.cycle_final_lr,
),
"interval": "step",
}
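# Minimal usage sketch (hypothetical hyperparameters, assuming a pytorch_lightning
# version compatible with this callback): start SWA weight averaging at epoch 40
# and cycle the lr once per epoch of 250 iterations.
def _example_swa_cycle_linear():
    import pytorch_lightning as pl
    swa = SWACycleLinear(
        swa_epoch_start=40,
        cycle_initial_lr=1e-3,
        cycle_final_lr=1e-6,
        num_iterations_per_epoch=250,
    )
    return pl.Trainer(max_epochs=50, callbacks=[swa])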
@@ -19,6 +19,7 @@ This is prototype code ... Use at your own risk
This was initially part of a notebook but I needed to move it into
this scriptish functions to run it in my default pipeline
"""
import pickle
from itertools import product
from pathlib import Path