# coding=utf-8
# Copyright 2021 The OneFlow Authors. All rights reserved.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import math
import operator
import time
from collections import Counter

import oneflow as flow

from libai.evaluation import flatten_results_dict
from libai.utils import distributed as dist
from libai.utils.checkpoint import Checkpointer
from libai.utils.checkpoint import PeriodicCheckpointer as _PeriodicCheckpointer
from libai.utils.events import EventWriter
from libai.utils.timer import Timer

from .trainer import HookBase

# --------------------------------------------------------
# References:
# https://github.com/facebookresearch/detectron2/blob/main/detectron2/engine/hooks.py
# --------------------------------------------------------

"""
Implement some common hooks.
"""

logger = logging.getLogger(__name__)


class CallbackHook(HookBase):
    """
    Create a hook using callback functions provided by the user.
    """

    def __init__(self, *, before_train=None, after_train=None, before_step=None, after_step=None):
        """
        Each argument is a function that takes one argument: the trainer.
        """
        self._before_train = before_train
        self._before_step = before_step
        self._after_step = after_step
        self._after_train = after_train

    def before_train(self):
        if self._before_train:
            self._before_train(self.trainer)

    def after_train(self):
        if self._after_train:
            self._after_train(self.trainer)
        # The functions may be closures that hold references to the trainer;
        # delete them to avoid circular references.
        del self._before_train, self._after_train
        del self._before_step, self._after_step

    def before_step(self):
        if self._before_step:
            self._before_step(self.trainer)

    def after_step(self):
        if self._after_step:
            self._after_step(self.trainer)
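# A minimal usage sketch for ``CallbackHook``. This is illustrative only: it
# assumes a trainer object exposing ``register_hooks`` (as LiBai's trainers,
# following detectron2's design, do); the lambdas are placeholders.
#
#   trainer.register_hooks([
#       CallbackHook(
#           before_train=lambda t: logger.info("training starts at iter %d", t.start_iter),
#           after_step=lambda t: logger.info("finished iter %d", t.iter),
#       )
#   ])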
""" self._warmup_iter = warmup_iter self._step_timer = Timer() def before_train(self): self._start_time = time.perf_counter() self._total_timer = Timer() self._total_timer.pause() def after_train(self): total_time = time.perf_counter() - self._start_time total_time_minus_hooks = self._total_timer.seconds() hook_time = total_time - total_time_minus_hooks num_iter = self.trainer.iter + 1 - self.trainer.start_iter - self._warmup_iter if num_iter > 0 and total_time_minus_hooks > 0: # Speed is meaningful only after warmup # NOTE this format is parsed by grep in some scripts logger.info( "Overall training speed: {} iterations in {} ({:.4f} s / it)".format( num_iter, str(datetime.timedelta(seconds=int(total_time_minus_hooks))), total_time_minus_hooks / num_iter, ) ) logger.info( "Total training time: {} ({} on hooks)".format( str(datetime.timedelta(seconds=int(total_time))), str(datetime.timedelta(seconds=int(hook_time))), ) ) def before_step(self): self._step_timer.reset() self._total_timer.resume() def after_step(self): # +1 because we're in after_step iter_done = self.trainer.iter - self.trainer.start_iter + 1 if iter_done >= self._warmup_iter: sec = self._step_timer.seconds() self.trainer.storage.put_scalars(time=sec) else: self._start_time = time.perf_counter() self._total_timer.reset() self._total_timer.pause() class PeriodicWriter(HookBase): """ Write events to EventStorage periodically. It is executed every ``period`` iterations and after the last iteration. """ def __init__(self, writers, period=20): """ Args: writers (list[EventWriter]): a list of EventWriter objects period (int): """ self._writers = writers for w in writers: assert isinstance(w, EventWriter), w self._period = period def after_step(self): if (self.trainer.iter + 1) % self._period == 0 or ( self.trainer.iter == self.trainer.max_iter - 1 ): for writer in self._writers: writer.write() def after_train(self): for writer in self._writers: writer.close() class PeriodicCheckpointer(_PeriodicCheckpointer, HookBase): """ Same as :class:`libai.utils.checkpoint.PeriodicCheckpointer`, but as a hook. Note that when used as a hook, it is unable to save additional data other than what's defined by the given `checkpointer`. It is executed every ``period`` iterations and after the last iteration. """ def before_train(self): self.max_iter = self.trainer.max_iter def after_step(self): self.step(self.trainer.iter) class BestCheckpointer(HookBase): """ Checkpoints best weights based off given metric. This hook should be used in conjunction to and executed after the hook that produces the metric, e.g. `EvalHook`. """ def __init__( self, eval_period: int, checkpointer: Checkpointer, val_metric: str, mode: str = "max", file_prefix: str = "model_best", ) -> None: """ Args: eval_period (int): the period `EvalHook` is set to run. checkpointer: the checkpointer object used to save checkpoints. val_metric (str): validation metric to track for best checkpoint, e.g. "acc@1" mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be maximized or minimized, e.g. for "acc@1" it should be "max" file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best" """ self._period = eval_period self._val_metric = val_metric assert mode in [ "max", "min", ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of {"max", "min"}.' 
class BestCheckpointer(HookBase):
    """
    Checkpoint the best model weights based on a given metric.

    This hook should be used in conjunction with, and executed after,
    the hook that produces the metric, e.g. `EvalHook`.
    """

    def __init__(
        self,
        eval_period: int,
        checkpointer: Checkpointer,
        val_metric: str,
        mode: str = "max",
        file_prefix: str = "model_best",
    ) -> None:
        """
        Args:
            eval_period (int): the period `EvalHook` is set to run.
            checkpointer: the checkpointer object used to save checkpoints.
            val_metric (str): validation metric to track for best checkpoint, e.g. "acc@1"
            mode (str): one of {'max', 'min'}. controls whether the chosen val metric should be
                maximized or minimized, e.g. for "acc@1" it should be "max"
            file_prefix (str): the prefix of checkpoint's filename, defaults to "model_best"
        """
        self._period = eval_period
        self._val_metric = val_metric
        assert mode in [
            "max",
            "min",
        ], f'Mode "{mode}" to `BestCheckpointer` is unknown. It should be one of "max" or "min".'
        if mode == "max":
            self._compare = operator.gt
        else:
            self._compare = operator.lt
        self._checkpointer = checkpointer
        self._file_prefix = file_prefix
        self.best_metric = None
        self.best_iter = None

    def _update_best(self, val, iteration):
        if math.isnan(val) or math.isinf(val):
            return False
        self.best_metric = val
        self.best_iter = iteration
        return True

    def _best_checking(self):
        metric_tuple = self.trainer.storage.latest().get(self._val_metric)
        flag = flow.zeros(1)
        if dist.is_main_process():
            if metric_tuple is None:
                logger.warning(
                    f"Given val metric {self._val_metric} does not seem to be computed/stored. "
                    "Will not be checkpointed based on that."
                )
            else:
                latest_metric, metric_iter = metric_tuple
                if self.best_metric is None:
                    if self._update_best(latest_metric, metric_iter):
                        flag = flag + 1
                        logger.info(
                            f"Saved first model at {self.best_metric:0.5f} @ {self.best_iter} steps"
                        )
                elif self._compare(latest_metric, self.best_metric):
                    flag = flag + 1
                    logger.info(
                        f"Saved best model as latest eval score for {self._val_metric} is "
                        f"{latest_metric:0.5f}, better than last best score "
                        f"{self.best_metric:0.5f} @ iteration {self.best_iter}."
                    )
                    self._update_best(latest_metric, metric_iter)
                else:
                    logger.info(
                        f"Not saving as latest eval score for "
                        f"{self._val_metric} is {latest_metric:0.5f}, "
                        f"not better than best score {self.best_metric:0.5f} "
                        f"@ iteration {self.best_iter}."
                    )
        dist.synchronize()
        # Broadcast the save decision from the main process to all ranks,
        # so that every rank calls (or skips) `save` consistently.
        flag = flag.to_global(
            sbp=flow.sbp.broadcast, placement=flow.env.all_device_placement("cpu")
        )
        if flag.to_local().item() == 1:
            self._checkpointer.save(self._file_prefix)

    def after_step(self):
        # same conditions as `EvalHook`
        next_iter = self.trainer.iter + 1
        if (
            self._period > 0
            and next_iter % self._period == 0
            and next_iter != self.trainer.max_iter
        ):
            self._best_checking()

    def after_train(self):
        # same conditions as `EvalHook`
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._best_checking()
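# Sketch of pairing ``BestCheckpointer`` with the ``EvalHook`` defined below.
# The order matters: ``BestCheckpointer`` must come after the hook producing
# the tracked metric. The key "acc@1" is only an example; it must match a key
# written to the trainer's storage by the evaluation function.
#
#   trainer.register_hooks([
#       EvalHook(eval_period, eval_func),
#       BestCheckpointer(eval_period, checkpointer, "acc@1", mode="max"),
#   ])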
class EvalHook(HookBase):
    """
    Run an evaluation function periodically, and at the end of training.

    It is executed every ``eval_period`` iterations and after the last iteration.
    """

    def __init__(self, eval_period, eval_function):
        """
        Args:
            eval_period (int): the period to run `eval_function`.
            eval_function (callable): a function which takes no arguments, and
                returns a nested dict of evaluation metrics.

        Note:
            This hook must be enabled in all or none workers.
            If you would like only certain workers to perform evaluation,
            give other workers a no-op function (`eval_function=lambda: None`).
        """
        self._period = eval_period
        self._func = eval_function

    def _do_eval(self):
        results = self._func()

        if results:
            assert isinstance(
                results, dict
            ), "Eval function must return a dict. Got {} instead.".format(results)

            flattened_results = flatten_results_dict(results)
            for k, v in flattened_results.items():
                try:
                    v = float(v)
                except Exception:
                    raise ValueError(
                        "[EvalHook] eval_function should return a nested dict of float. "
                        "Got '{}: {}' instead.".format(k, v)
                    )
            self.trainer.storage.put_scalars(**flattened_results, smoothing_hint=False)

        # Evaluation may take different time among workers.
        # A barrier makes them start the next iteration together.
        dist.synchronize()

    def after_step(self):
        next_iter = self.trainer.iter + 1
        if self._period > 0 and next_iter % self._period == 0:
            # do the last eval in after_train
            if next_iter != self.trainer.max_iter:
                self._do_eval()

    def after_train(self):
        # This condition is to prevent the eval from running after a failed training
        if self.trainer.iter + 1 >= self.trainer.max_iter:
            self._do_eval()
        # func is likely a closure that holds references to the trainer,
        # so delete it at the end to avoid circular references.
        del self._func


class LRScheduler(HookBase):
    """
    A hook which executes a OneFlow builtin LR scheduler and summarizes the LR.
    It is executed after every iteration.
    """

    def __init__(self, optimizer=None, scheduler=None):
        """
        Args:
            optimizer (flow.optim.Optimizer): the optimizer whose LR is summarized.
            scheduler (flow.optim.lr_scheduler._LRScheduler): the scheduler whose
                ``step`` is called after every iteration.

        If any argument is not given, will try to obtain it from the trainer.
        """
        self._optimizer = optimizer
        self._scheduler = scheduler

    def before_train(self):
        self._optimizer = self._optimizer or self.trainer.optimizer
        self._best_param_group_id = LRScheduler.get_best_param_group_id(self._optimizer)

    @staticmethod
    def get_best_param_group_id(optimizer):
        # NOTE: some heuristics on what LR to summarize
        # summarize the param group with the most parameters
        largest_group = max(len(g["params"]) for g in optimizer.state_dict()["param_groups"])

        if largest_group == 1:
            # If all groups have one parameter,
            # then find the most common initial LR, and use it for summary
            lr_count = Counter(
                [g["_options"]["lr"] for g in optimizer.state_dict()["param_groups"]]
            )
            lr = lr_count.most_common()[0][0]
            for i, g in enumerate(optimizer.state_dict()["param_groups"]):
                if g["_options"]["lr"] == lr:
                    return i
        else:
            for i, g in enumerate(optimizer.state_dict()["param_groups"]):
                if len(g["params"]) == largest_group:
                    return i

    def after_step(self):
        lr = self.scheduler.get_last_lr()[self._best_param_group_id]
        self.trainer.storage.put_scalar("lr", lr, smoothing_hint=False)
        self.scheduler.step()

    @property
    def scheduler(self):
        return self._scheduler or self.trainer.lr_scheduler

    def state_dict(self):
        if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
            return self.scheduler.state_dict()
        return {}

    def load_state_dict(self, state_dict):
        if isinstance(self.scheduler, flow.optim.lr_scheduler._LRScheduler):
            logger.info("Loading scheduler from state_dict ...")
            self.scheduler.load_state_dict(state_dict)
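# End-to-end sketch of a typical hook list. Hedged: this mirrors the ordering
# conventions noted above (timer first for accurate timing; writer last, by
# detectron2 convention, so it can log what other hooks produce). ``trainer``,
# ``optimizer``, ``lr_scheduler``, and ``writers`` are assumed to exist, and
# the eval lambda is an illustrative stand-in for a real evaluation function.
#
#   trainer.register_hooks([
#       IterationTimer(),
#       LRScheduler(optimizer, lr_scheduler),
#       EvalHook(1000, lambda: {"top1": {"acc": 0.0}}),
#       PeriodicWriter(writers, period=20),
#   ])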