Unverified commit 3defa32a authored by Frank Lee, committed by GitHub

Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation and fix the LR scheduler

* Fix the FP16 optimizer and adapt torch amp to work with tensor parallelism (#18)

* Fix compatibility bugs between torch amp and tensor parallelism, plus some minor fixes

* fixed trainer

* Revert "fixed trainer"

  This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* Improve consistency between trainer, engine and schedule (#23)

Co-authored-by: 1SAA <c2h214748@gmail.com>
Co-authored-by: ver217 <lhx0217@gmail.com>
parent 2b05de4c
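Note: with this commit, hooks are built inside Trainer.fit() from a hooks_cfg list and the epoch/step budget also moves into fit(). A rough usage sketch of the updated API follows; the engine, dataloaders and NUM_EPOCHS below are assumed to come from the user's own launch code and are illustrative only, not part of the diff.

# Minimal sketch of the new trainer API (assumes `engine`, `train_dataloader`,
# `test_dataloader` and `NUM_EPOCHS` are defined by the caller).
from colossalai.trainer import Trainer
from colossalai.utils import MultiTimer

trainer = Trainer(engine=engine, verbose=True, timer=MultiTimer())
trainer.fit(
    train_dataloader=train_dataloader,
    epochs=NUM_EPOCHS,
    test_dataloader=test_dataloader,
    test_interval=1,
    hooks_cfg=[
        dict(type='LossHook'),
        dict(type='LogMetricByEpochHook'),
        dict(type='LRSchedulerHook',
             by_epoch=True,
             lr_scheduler_cfg=dict(type='LinearWarmupLR', warmup_steps=5)),
    ],
    display_progress=True,
)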
@@ -106,7 +106,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
         no_tensor_parallel_grads = _calc_lp(
             no_tensor_parallel_grads, norm_type)
-        if gpc.is_initialized(ParallelMode.TENSOR):
+        if gpc.is_initialized(ParallelMode.TENSOR) and len(tensor_parallel_grads) > 0:
             # Sum across all model-parallel GPUs.
             torch.distributed.all_reduce(tensor_parallel_norm,
                                          op=torch.distributed.ReduceOp.SUM,
...
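The change above skips the tensor-parallel all-reduce when a rank holds no tensor-parallel gradients. A minimal illustration of the guarded-reduction pattern (a hypothetical standalone helper, not code from this commit):

import torch
import torch.distributed as dist

def reduce_tp_norm(tensor_parallel_norm: torch.Tensor, tp_group=None, has_tp_grads: bool = True):
    # Sum the local norm contribution across tensor-parallel ranks, but only
    # when this rank actually holds tensor-parallel gradients to account for.
    if dist.is_initialized() and has_tp_grads:
        dist.all_reduce(tensor_parallel_norm, op=dist.ReduceOp.SUM, group=tp_group)
    return tensor_parallel_norm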
@@ -6,6 +6,7 @@ import math
 import torch
 import torch.distributed as dist
 try:
     from deepspeed.git_version_info import version
     from deepspeed.moe.utils import is_moe_param
@@ -13,7 +14,7 @@ try:
     from deepspeed.ops.op_builder import UtilsBuilder
     from deepspeed.runtime.zero.config import ZERO_OPTIMIZATION_GRADIENTS
 except ImportError:
-    print('DeepSpeed is required if you want to use ZeRO.')
+    pass
 from packaging import version as pkg_version
 from torch._six import inf
 from torch.distributed.distributed_c10d import _get_global_rank
...
@@ -21,7 +21,7 @@ try:
     from deepspeed.runtime.zero.partition_parameters import *
     from deepspeed.runtime.zero.partition_parameters import _init_external_params
 except ImportError:
-    print('DeepSpeed is required if you want to use ZeRO.')
+    pass
 from torch._six import inf
 from torch.distributed.distributed_c10d import _get_global_rank
...
@@ -20,3 +20,4 @@ TRANSFORMS = Registry('transforms', third_party_library=[transforms])
 PIPE_ALLOC_POLICY = Registry('pipeline_allocation_policy')
 SAMPLERS = Registry('samplers')
 LR_SCHEDULERS = Registry('lr_schedulers')
+SCHEDULE = Registry('schedules')
 from ._trainer import Trainer
 from .hooks import *
-from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D
+from .metric import Loss, Accuracy2D, Accuracy3D, Accuracy2p5D, LearningRate
 
-__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D']
+__all__ = ['Trainer', 'Loss', 'Accuracy3D', 'Accuracy2D', 'Accuracy2p5D', 'LearningRate']
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from typing import Optional
 from typing import Union, List
 
 import torch
@@ -10,12 +9,11 @@ from torch.utils.data import DataLoader
 from tqdm import tqdm
 
 from colossalai.builder import build_hooks
-from colossalai.checkpointing import save_checkpoint, load_checkpoint, get_checkpoint_path
-from colossalai.context import Config
 from colossalai.engine import Engine
 from colossalai.logging import get_global_dist_logger
-from colossalai.utils import get_global_multitimer, is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
 from colossalai.nn.data import DataParallelSampler
+from colossalai.utils import MultiTimer
+from colossalai.utils import is_dp_rank_0, is_tp_rank_0, is_no_pp_or_last_stage
 
 
 class Trainer:
@@ -30,43 +28,31 @@ class Trainer:
     :type hoooks_cfg: Config, optional
     :type verbose: bool, optional
     """
 
     def __init__(self,
                  engine: Engine,
-                 hooks_cfg: Optional[Config] = None,
-                 verbose: bool = False):
+                 verbose: bool = False,
+                 timer: MultiTimer = None):
         # training-ralated params
         self._engine = engine
-        self._max_epochs = float('inf')
-        self._max_steps = float('inf')
+        self._max_epochs = 0
         self._cur_epoch = 0
+        self._max_steps = 0
         self._cur_step = 0
+        self._steps_per_epoch = 0
 
-        # data-related params
-        self._train_dataloader = None
-        self._test_dataloader = None
 
         # misc params
-        self._display_progress = False
         self._logger = get_global_dist_logger()
         self._verbose = verbose
 
         # hooks can store states in this dict, and could be consumed by other hooks
-        self.states = {}
+        self.states = dict()
 
         # build hooks
         self.hooks = list()
-        if hooks_cfg is not None:
-            for cfg in hooks_cfg:
-                hook = build_hooks(cfg, self)
-                self.hooks.append(hook)
-        self.hooks.sort(key=lambda hook: hook.priority)
-        if self._verbose:
-            for hook in self.hooks:
-                self._logger.info(
-                    f'build {hook.__class__.__name__} for train, priority = {hook.priority}', ranks=[0])
 
-        # timer
-        self._timer = get_global_multitimer()
+        # multi-timer for time benchmarking
+        self._timer = timer
 
     @property
     def cur_epoch(self):
@@ -74,13 +60,65 @@ class Trainer:
         """
         return self._cur_epoch
 
+    @cur_epoch.setter
+    def cur_epoch(self, epoch: int):
+        """Set how many epochs have been processed.
+        """
+        # allow setter for training resumption
+        self._cur_epoch = epoch
+
     @property
     def cur_step(self):
         """Returns how many iteration steps have been processed.
         """
         return self._cur_step
 
-    def call_hooks(self, func, output=None):
+    @property
+    def max_epochs(self):
+        return self._max_epochs
+
+    @property
+    def max_steps(self):
+        return self._max_steps
+
+    @property
+    def steps_per_epoch(self):
+        return self._steps_per_epoch
+
+    @property
+    def engine(self):
+        return self._engine
+
+    @engine.setter
+    def engine(self, engine_: Engine):
+        self._engine = engine_
+
+    def _set_current_step(self, epoch: int):
+        """Sets current step number.
+
+        :param epoch: Step number to be set
+        :type epoch: int
+        """
+        self._cur_step = epoch * self._steps_per_epoch
+
+    def _call_timer(self, action: str, item: str, *args, **kwargs) -> None:
+        """Call timer funciton with a given timer name.
+
+        :param action: Function to be called on timer
+        :type action: str
+        :param item: Name of the timer
+        :type item: str
+        """
+        if self._timer is not None:
+            getattr(self._timer, action)(item, *args, **kwargs)
+
+    def _reset_states(self) -> None:
+        """Clear trainer states
+        """
+        self.states = dict()
+
+    def _call_hooks(self, func, output=None):
         """Calls specific hooks in the current time point.
 
         :param func: A string represents the time point
@@ -95,161 +133,186 @@ class Trainer:
             else:
                 getattr(hook, func)(*output)
 
-    def exceed_max_step(self):
-        """Checks whether the trainer exceeds the maximum number of runnning iterations.
+    @staticmethod
+    def _should_display_progress(display_progress: bool):
+        """ Only display progress on DP rank 0, TP rank 0 and PP last rank
         """
-        return self._cur_step >= self._max_steps
-
-    def set_epoch(self, epoch):
-        """Sets current epoch number.
-
-        :param epoch: Epoch number to be set
-        :type epoch: int
-        """
-        self._cur_epoch = epoch
-
-    def _recover_steps(self):
-        step = self.cur_step * self._engine.schedule.num_steps
-        self._cur_step = step
-
-    def _set_display_progress(self, display_progress: bool):
-        self._display_progress = display_progress and is_dp_rank_0(
-        ) and is_tp_rank_0() and is_no_pp_or_last_stage()
+        return display_progress and is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
 
-    def _train_epoch(self, epoch: int = None):
+    def _train_epoch(self,
+                     train_dataloader: DataLoader,
+                     epoch: int = None,
+                     display_progress: bool = False):
         # set sampler epoch
         if epoch is not None and \
-                hasattr(self._engine.train_dataloader, 'sampler') and \
-                isinstance(self._engine.train_dataloader.sampler, DataParallelSampler):
-            self._engine.train_dataloader.sampler.set_epoch(epoch)
+                hasattr(train_dataloader, 'sampler') and \
+                isinstance(train_dataloader.sampler, DataParallelSampler):
+            train_dataloader.sampler.set_epoch(epoch)
 
+        # set training state
         self._engine.train()
+        data_iter = iter(train_dataloader)
 
-        progress = range(self._engine.schedule.num_steps)
-        if self._display_progress:
+        progress = range(self._steps_per_epoch)
+        if display_progress:
             if epoch is None:
                 progress = tqdm(progress, desc='[Train]')
             else:
                 progress = tqdm(progress, desc=f'[Epoch {epoch} train]')
 
         # train 1 epoch
-        self.call_hooks('before_train_epoch')
-        self._timer.start('train-epoch')
-        for _ in progress:
-            self._cur_step += 1
+        self._call_hooks('before_train_epoch')
+        self._call_timer(action='start', item='train-epoch')
+        for i in progress:
+            self._call_hooks('before_train_iter')
+            self._call_timer(action='start', item='train-step')
 
-            self.call_hooks('before_train_iter')
-            self._timer.start('train-step')
-            logits, label, loss = self._engine.step()
-            self._timer.stop('train-step', keep_in_history=True)
-            self.call_hooks('after_train_iter', output=(logits, label, loss))
+            if i == self._steps_per_epoch - 1:
+                is_last_iteration = True
+            else:
+                is_last_iteration = False
+
+            # run 1 training step
+            logits, label, loss = self._engine.step(data_iter, is_last_iteration)
+            self._call_timer(action='stop', item='train-step', keep_in_history=True)
+            self._call_hooks('after_train_iter', output=(logits, label, loss))
+
+            self._cur_step += 1
 
-            if self.exceed_max_step():
             # stop when max iter is reached
+            if self._exceed_max_step():
                 break
 
-        self._timer.stop('train-epoch', keep_in_history=True)
-        self.call_hooks('after_train_epoch')
-        self._timer.reset('train-step')
+        self._call_timer(action='stop', item='train-epoch', keep_in_history=True)
+        self._call_hooks('after_train_epoch')
+        self._call_timer(action='reset', item='train-step')
 
     def _eval(self,
+              test_dataloader: DataLoader,
              epoch: int = None,
-             return_loss: bool = True):
+             display_progress: bool = False):
         # switch engine status
         self._engine.eval()
 
-        self.call_hooks('before_test')
+        data_iter = iter(test_dataloader)
+        num_steps = len(test_dataloader)
+
+        self._call_hooks('before_test')
         with torch.no_grad():
             # prepare progress bar
-            progress = range(self._engine.schedule.num_steps)
-            if self._display_progress:
+            progress = range(num_steps)
+            if display_progress:
                 desc = 'Evaluation'
                 if epoch is not None:
                     desc = '[Epoch %d val]' % epoch
                 progress = tqdm(progress, desc=desc)
 
-            self.call_hooks('before_test_epoch')
-            self._timer.start('test-epoch')
+            self._call_hooks('before_test_epoch')
+            self._call_timer(action='start', item='test-epoch')
             for _ in progress:
-                self.call_hooks('before_test_iter')
-                self._timer.start('test-step')
-                logits, label, loss = self._engine.step(
-                    return_loss=return_loss)
-                self._timer.stop('test-step', keep_in_history=True)
-                self.call_hooks('after_test_iter',
+                self._call_hooks('before_test_iter')
+                self._call_timer(action='start', item='test-step')
+                logits, label, loss = self._engine.step(data_iter, return_loss=True)
+                self._call_timer(action='stop', item='test-step', keep_in_history=True)
+                self._call_hooks('after_test_iter',
                                 output=(logits, label, loss))
-            self._timer.stop('test-epoch', keep_in_history=True)
-            self.call_hooks('after_test_epoch')
-        self.call_hooks('after_test')
-        self._timer.reset('test-step')
-        self._timer.reset('test-epoch')
+            self._call_timer(action='stop', item='test-epoch', keep_in_history=True)
+            self._call_hooks('after_test_epoch')
+        self._call_hooks('after_test')
+        self._call_timer(action='reset', item='test-step')
+        self._call_timer(action='reset', item='test-epoch')
+
+    def _exceed_max_step(self):
+        return self._max_steps is not None and self._cur_step > self._max_steps
     def fit(self,
             train_dataloader: DataLoader,
-            test_dataloader: DataLoader = None,
-            max_epochs: int = None,
+            epochs: int,
             max_steps: int = None,
+            test_dataloader: DataLoader = None,
             test_interval: int = 1,
-            display_progress: bool = False):
+            hooks_cfg: dict = None,
+            display_progress: bool = False,
+            ):
         """Trains the model to fit training data.
 
         :param train_dataloader: DataLoader in training
-        :param test_dataloader: DataLoader in testing
-        :param max_epochs: Maximum number of epoches
+        :param epochs: Maximum number of epoches
         :param max_steps: Maximum number of running iterations
+        :param test_dataloader: DataLoader in testing
         :param test_interval: Interval of testing
+        :param hooks_cfg: A list of hook configuration
         :param display_progress: If True, the training progress will be printed
         :type train_dataloader: DataLoader
-        :type test_dataloader: DataLoader
-        :type max_epochs: int
+        :type epochs: int
         :type max_steps: int
+        :type test_dataloader: DataLoader
         :type test_interval: int
+        :type hooks_cfg: dict
         :type display_progress: bool
+        :type gradient_accumulation: int
         """
 
-        # prepare dataloaders
-        self._train_dataloader = train_dataloader
-        self._engine.set_dataloader(self._train_dataloader, train=True)
-        self._engine.train()
+        # set epochs and steps, consider gradient accumulation
+        self._steps_per_epoch = len(train_dataloader) // self._engine.gradient_accumulation
+        self._max_steps = max_steps
+        self._max_epochs = epochs
 
+        # check if testing is required
         should_test = False
         if test_dataloader is not None:
-            self._test_dataloader = test_dataloader
-            self._engine.set_dataloader(self._test_dataloader, train=False)
             should_test = True
 
-        # decide the
-        if max_epochs is not None:
-            self._max_epochs = max_epochs
-        if max_steps is not None:
-            self._max_steps = max_steps
-        self._set_display_progress(display_progress)
+        display_progress = self._should_display_progress(display_progress)
+
+        # reset hooks
+        self._reset_states()
+        self.hooks = list()
+
+        # build hooks
+        if hooks_cfg is not None:
+            for cfg in hooks_cfg:
+                hook = build_hooks(cfg, self)
+                self.hooks.append(hook)
+        self.hooks.sort(key=lambda hook: hook.priority)
+        if self._verbose:
+            for hook in self.hooks:
+                self._logger.info(
+                    f'build {hook.__class__.__name__} for training, priority = {hook.priority}', ranks=[0])
+            self._logger.info("Lower value means higher priority for calling hook function")
 
         # start train
-        self.call_hooks('before_train')
+        self._engine.train()
+        self._call_hooks('before_train')
 
         # recover step value if resuming training
-        if self.cur_epoch != 0:
-            self._recover_steps()
-
         last_epoch = self._cur_epoch
+        if self.cur_epoch != 0:
+            self._set_current_step(last_epoch)
 
-        for epoch in range(last_epoch, self._max_epochs):
-            self._cur_epoch += 1
-
+        for epoch in range(last_epoch, epochs):
             # train for one epoch
-            self._train_epoch(epoch)
+            self._train_epoch(
+                train_dataloader=train_dataloader,
+                epoch=epoch,
+                display_progress=display_progress
+            )
 
             # start eval
             if should_test and epoch % test_interval == 0:
-                self._eval(epoch, return_loss=True)
+                self._eval(test_dataloader=test_dataloader,
+                           display_progress=display_progress,
+                           epoch=epoch,
+                           )
+
+            self._cur_epoch += 1
 
             # check for termination
-            if self.exceed_max_step():
+            if self._exceed_max_step():
                 self._logger.info(
-                    f"Max number of steps {self._max_steps} has been reached, training is stopped automatically")
+                    f"Max number of steps {max_steps} has been reached, training is stopped automatically")
                 break
-        self.call_hooks('after_train')
-        self._timer.reset('train-epoch')
+        self._call_hooks('after_train')
+        self._call_timer('reset', 'train-epoch')
     def evaluate(self,
                  test_dataloader: DataLoader,
@@ -261,15 +324,13 @@ class Trainer:
         :type test_dataloader: DataLoader
         :type display_progress: bool, optional
         """
-        # set dataloader
-        self._test_dataloader = test_dataloader
-        self._engine.set_dataloader(self._test_dataloader, train=True)
-
-        # set
-        self._set_display_progress(display_progress)
+        # set display
+        display_progress = self._should_display_progress(display_progress)
 
         # eval
-        self._eval(return_loss=True)
+        self._eval(test_dataloader=test_dataloader,
+                   display_progress=display_progress,
+                   )
 
     def predict(self, data: Union[Tensor, List[Tensor]]):
         """Uses trained model to make a prediction for a tensor or a tensor list.
@@ -289,45 +350,6 @@ class Trainer:
         # prepare a list of (data, label) to make it iterable
         # for compatibility with schedule
         simple_dataloader = [(data, None)]
-        self._engine.set_dataloader(simple_dataloader)
-        output, _, _ = self._engine.step(return_loss=False)
+        data_iter = iter(simple_dataloader)
+        output, _, _ = self._engine.step(data_iter, return_loss=False)
         return output
-    def save(self, path: str, suffix: str = ''):
-        """Saves the model to a file.
-
-        :param path: Relative path of the file
-        :param suffix: Suffix of the file
-        :type path: str
-        :type suffix: str, optional
-        """
-        save_path = get_checkpoint_path(path,
-                                        self._cur_epoch,
-                                        suffix=suffix)
-        save_checkpoint(save_path, self._cur_epoch, self._engine.get_model(),
-                        self._engine.get_optimizer(),
-                        self._engine.get_lr_scheduler())
-
-    def load(self,
-             path: str,
-             finetune: bool = False,
-             strict: bool = False):
-        """Loads parameters to the model from a file.
-
-        :param path: Relative path of the file
-        :param finetune: Whether allows to load a part of the model
-        :param strict: Whether loads a model that has the same shape of parameters
-        :type path: str
-        :type finetune: bool, optional
-        :type strict: bool, optional
-        """
-        last_epoch, _ = load_checkpoint(path,
-                                        self._engine.get_model(),
-                                        self._engine.get_optimizer(),
-                                        self._engine.get_lr_scheduler(),
-                                        finetune=finetune,
-                                        strict=strict)
-        if finetune:
-            self.set_epoch(0)
-        else:
-            self.set_epoch(last_epoch)
@@ -2,10 +2,12 @@ from ._base_hook import BaseHook
 from ._checkpoint_hook import SaveCheckpointHook, LoadCheckpointHook
 from ._metric_hook import LossHook, Accuracy2DHook, AccuracyHook, MetricHook
 from ._log_hook import LogMetricByEpochHook, TensorboardHook, LogTimingByEpochHook, LogMemoryByEpochHook
+from ._lr_scheduler_hook import LRSchedulerHook
 
 __all__ = [
     'BaseHook', 'MetricHook',
     'LoadCheckpointHook', 'SaveCheckpointHook',
     'LossHook', 'AccuracyHook', 'Accuracy2DHook',
     'LogMetricByEpochHook', 'TensorboardHook', 'LogTimingByEpochHook', 'LogMemoryByEpochHook',
+    'LRSchedulerHook'
 ]
@@ -3,13 +3,13 @@
 import os.path as osp
 
+import torch.distributed as dist
+
-from colossalai.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
 from colossalai.registry import HOOKS
-from colossalai.trainer.hooks import BaseHook
 from colossalai.trainer import Trainer
+from colossalai.trainer.hooks import BaseHook
 from colossalai.utils import is_dp_rank_0
+from colossalai.utils.checkpointing import get_latest_checkpoint_path, get_checkpoint_path
+from colossalai.utils.checkpointing import save_checkpoint, load_checkpoint
+from ._lr_scheduler_hook import LRSchedulerHook
 
 
 @HOOKS.register_module
@@ -33,7 +33,7 @@ class SaveCheckpointHook(BaseHook):
                  interval: int = 1,
                  checkpoint_dir: str = None,
                  suffix: str = '',
-                 priority: int = 0):
+                 priority: int = 10):
         super().__init__(trainer=trainer, priority=priority)
         assert isinstance(trainer, Trainer), \
             f'SaveCheckpointHook expects a Trainer, got {type(trainer)}'
@@ -41,6 +41,16 @@ class SaveCheckpointHook(BaseHook):
         self.checkpoint_dir = checkpoint_dir
         self.suffix = suffix
 
+        # get lr scheduler from the LRSchedulerHook before train
+        self._lr_scheduler = None
+
+    def before_train(self):
+        # check if lr scheduler is present in LRSchedulerHook
+        for hook in self.trainer.hooks:
+            if isinstance(hook, LRSchedulerHook):
+                self._lr_scheduler = hook.lr_scheduler
+                break
+
     def after_train_epoch(self):
         """Saves the model after a training epoch.
         """
@@ -48,14 +58,18 @@ class SaveCheckpointHook(BaseHook):
         if self.trainer.cur_epoch % self.interval == 0:
             # only gpus with data parallel rank equals to 0 write to the disk
             if is_dp_rank_0():
-                self.trainer.save(path=self.checkpoint_dir, suffix=self.suffix)
+                save_path = get_checkpoint_path(self.checkpoint_dir,
+                                                self.trainer.cur_epoch,
+                                                suffix=self.suffix)
+                save_checkpoint(save_path,
+                                self.trainer.cur_epoch,
+                                self.trainer.engine.model,
+                                self.trainer.engine.optimizer,
+                                self._lr_scheduler)
                 self.logger.info(
                     f'checkpoint for epoch {self.trainer.cur_epoch} is saved to {self.checkpoint_dir}')
 
+            # wait until everyone is done
+            if dist.is_initialized():
+                dist.barrier()
+
 
 @HOOKS.register_module
 class LoadCheckpointHook(BaseHook):
@@ -81,30 +95,46 @@ class LoadCheckpointHook(BaseHook):
                  epoch: int = -1,
                  finetune: bool = False,
                  strict: bool = False,
-                 priority: int = 10) -> None:
+                 suffix: str = '',
+                 priority: int = 0) -> None:
+        super().__init__(trainer=trainer, priority=priority)
         assert isinstance(trainer, Trainer), \
             f'LoadLatestCheckpointHook excepts a Trainer, got {type(trainer)}'
         self.epoch = epoch
         self.checkpoint_dir = checkpoint_dir
         self.finetune = finetune
+        self.suffix = suffix
         self.strict = strict
-        super().__init__(trainer=trainer, priority=priority)
 
     def before_train(self):
         """Loads parameters to the model before training.
         """
+        # check if lr scheduler is present in LRSchedulerHook
+        lr_scheduler = None
+        for hook in self.trainer.hooks:
+            if isinstance(hook, LRSchedulerHook):
+                lr_scheduler = hook.lr_scheduler
+                break
+
+        # use latest checkpoint if epoch = -1
         if self.epoch == -1:
-            path = get_latest_checkpoint_path(self.checkpoint_dir)
+            path = get_latest_checkpoint_path(self.checkpoint_dir, suffix=self.suffix)
         else:
-            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch)
+            path = get_checkpoint_path(self.checkpoint_dir, epoch=self.epoch, suffix=self.suffix)
         if osp.exists(path):
-            self.trainer.load(
-                path, finetune=self.finetune, strict=self.strict)
+            last_epoch, _ = load_checkpoint(path,
+                                            self.trainer.engine.model,
+                                            self.trainer.engine.optimizer,
+                                            lr_scheduler,
+                                            finetune=self.finetune,
+                                            strict=self.strict)
+            if self.finetune:
+                self.trainer.cur_epoch = 0
+            else:
+                self.trainer.cur_epoch = last_epoch
             self.logger.info(
                 f'loaded checkpoint from {path}')
         else:
             raise FileNotFoundError(f'checkpoint is not found at {path}')
+
+        # Some utilities want to load a checkpoint without distributed being initialized
+        if dist.is_initialized():
+            dist.barrier()
@@ -5,7 +5,7 @@ import os
 import os.path as osp
 
 import torch
-from tensorboardX import SummaryWriter
+from torch.utils.tensorboard import SummaryWriter
 
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
@@ -13,7 +13,7 @@ from colossalai.registry import HOOKS
 from colossalai.trainer._trainer import Trainer
 from colossalai.utils import get_global_multitimer, set_global_multitimer_status, report_memory_usage, is_dp_rank_0, \
     is_tp_rank_0, is_no_pp_or_last_stage
-from ._metric_hook import MetricHook
+from ._base_hook import BaseHook
 
 
 def _format_number(val):
@@ -24,7 +24,7 @@ def _format_number(val):
     return val
 
 
-class EpochIntervalHook(MetricHook):
+class EpochIntervalHook(BaseHook):
     def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1):
         super().__init__(trainer, priority)
         self._interval = interval
@@ -45,7 +45,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
     :type priority: int, optional
     """
 
-    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 1) -> None:
+    def __init__(self, trainer: Trainer, interval: int = 1, priority: int = 10) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
         self._is_rank_to_log = is_dp_rank_0() and is_tp_rank_0() and is_no_pp_or_last_stage()
@@ -74,7 +74,7 @@ class LogMetricByEpochHook(EpochIntervalHook):
 
 @HOOKS.register_module
-class TensorboardHook(MetricHook):
+class TensorboardHook(BaseHook):
     """Specialized Hook to record the metric to Tensorboard.
 
     :param trainer: Trainer attached with current hook
@@ -85,59 +85,71 @@ class TensorboardHook(MetricHook):
     :type priority: int, optional
     """
 
-    def __init__(self, trainer: Trainer, log_dir: str, priority: int = 1) -> None:
+    def __init__(self,
+                 trainer: Trainer,
+                 log_dir: str,
+                 dp_rank_0_only: bool = True,
+                 tp_rank_0_only: bool = True,
+                 priority: int = 10,
+                 ) -> None:
         super().__init__(trainer=trainer, priority=priority)
-        self._is_rank_to_log = is_no_pp_or_last_stage()
 
-        if self._is_rank_to_log:
+        # create log dir
+        if not gpc.is_initialized(ParallelMode.GLOBAL) or gpc.get_global_rank() == 0:
+            os.makedirs(log_dir, exist_ok=True)
+
+        # determine the ranks to generate tensorboard logs
+        self._is_valid_rank_to_log = is_no_pp_or_last_stage()
+
+        if dp_rank_0_only:
+            self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_dp_rank_0()
+
+        if tp_rank_0_only:
+            self._is_valid_rank_to_log = self._is_valid_rank_to_log and is_tp_rank_0()
+
+        if self._is_valid_rank_to_log:
             # create workspace on only one rank
             if gpc.is_initialized(ParallelMode.GLOBAL):
                 rank = gpc.get_global_rank()
             else:
                 rank = 0
-            log_dir = osp.join(log_dir, f'rank_{rank}')
 
             # create workspace
-            if not osp.exists(log_dir):
-                os.makedirs(log_dir)
+            log_dir = osp.join(log_dir, f'rank_{rank}')
+            os.makedirs(log_dir, exist_ok=True)
 
             self.writer = SummaryWriter(
                 log_dir=log_dir, filename_suffix=f'_rank_{rank}')
 
-    def after_train_iter(self, *args):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
+    def _log_by_iter(self, mode: str):
+        for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
             if metric_calculator.epoch_only:
                 continue
             val = metric_calculator.get_last_step_value()
-            if self._is_rank_to_log:
-                self.writer.add_scalar(
-                    f'{metric_name}/train', val, self.trainer.cur_step)
 
-    def after_test_iter(self, *args):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
-            if metric_calculator.epoch_only:
-                continue
-            val = metric_calculator.get_last_step_value()
-            if self._is_rank_to_log:
-                self.writer.add_scalar(f'{metric_name}/test', val,
+            if self._is_valid_rank_to_log:
+                self.writer.add_scalar(f'{metric_name}/{mode}', val,
                                        self.trainer.cur_step)
 
-    def after_test_epoch(self):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['test'].items():
+    def _log_by_epoch(self, mode: str):
+        for metric_name, metric_calculator in self.trainer.states['metrics'][mode].items():
             if metric_calculator.epoch_only:
                 val = metric_calculator.get_accumulated_value()
-                if self._is_rank_to_log:
-                    self.writer.add_scalar(f'{metric_name}/test', val,
+                if self._is_valid_rank_to_log:
+                    self.writer.add_scalar(f'{metric_name}/{mode}', val,
                                            self.trainer.cur_step)
 
+    def after_test_iter(self, *args):
+        self._log_by_iter(mode='test')
+
+    def after_test_epoch(self):
+        self._log_by_epoch(mode='test')
+
+    def after_train_iter(self, *args):
+        self._log_by_iter(mode='train')
+
     def after_train_epoch(self):
-        for metric_name, metric_calculator in self.trainer.states['metrics']['train'].items():
-            if metric_calculator.epoch_only:
-                val = metric_calculator.get_accumulated_value()
-                if self._is_rank_to_log:
-                    self.writer.add_scalar(f'{metric_name}/train', val,
-                                           self.trainer.cur_step)
+        self._log_by_epoch(mode='train')
 
 
 @HOOKS.register_module
@@ -157,7 +169,7 @@ class LogTimingByEpochHook(EpochIntervalHook):
     def __init__(self,
                  trainer: Trainer,
                  interval: int = 1,
-                 priority: int = 1,
+                 priority: int = 10,
                  log_eval: bool = True
                  ) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
@@ -217,7 +229,7 @@ class LogMemoryByEpochHook(EpochIntervalHook):
     def __init__(self,
                  trainer: Trainer,
                  interval: int = 1,
-                 priority: int = 1,
+                 priority: int = 10,
                  log_eval: bool = True
                  ) -> None:
         super().__init__(trainer=trainer, interval=interval, priority=priority)
...
from torch import Tensor

from colossalai.builder import build_lr_scheduler
from colossalai.registry import HOOKS
from ._metric_hook import MetricHook
from .._trainer import Trainer
from ..metric import LearningRate


@HOOKS.register_module
class LRSchedulerHook(MetricHook):
    """Build LR scheduler

    :param trainer: Trainer attached with current hook
    :type trainer: Trainer
    :param lr_scheduler_cfg: The config of LR scheduler
    :type lr_scheduler_cfg: dict
    :param by_epoch: If `True`, the LR will be scheduled every epoch. Else, the LR will be scheduled every batch. Defaults to `True`.
    :type by_epoch: bool
    :param priority: Priority in the printing, hooks with small priority will be printed in front
    :type priority: int, optional
    """

    def __init__(self,
                 trainer: Trainer,
                 lr_scheduler_cfg: dict,
                 by_epoch: bool = True,
                 store_lr_in_state: bool = True,
                 priority: int = 1,
                 ):
        super().__init__(trainer=trainer, priority=priority)
        self.by_epoch = by_epoch

        if by_epoch:
            total_steps = trainer.max_epochs
        else:
            total_steps = trainer.max_epochs * trainer.steps_per_epoch
            if trainer.max_steps is not None:
                total_steps = min(total_steps, trainer.max_steps)

        lr_scheduler_cfg['total_steps'] = total_steps

        self.lr_scheduler = build_lr_scheduler(
            lr_scheduler_cfg, trainer.engine.optimizer)

        if store_lr_in_state:
            self.trainer.states['metrics']['train']['lr'] = LearningRate(epoch_only=by_epoch,
                                                                         initial_lr=self.lr_scheduler.get_lr()[0])

    def after_train_epoch(self):
        if self.by_epoch:
            self.lr_scheduler.step()
            self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])

    def after_train_iter(self, output: Tensor, label: Tensor, loss: Tensor):
        if not self.by_epoch:
            self.lr_scheduler.step()
            self.trainer.states['metrics']['train']['lr'].update(self.lr_scheduler.get_lr()[0])
@@ -21,9 +21,12 @@ class MetricHook(BaseHook):
     :type priority: int
     """
 
-    def __init__(self, trainer: Trainer, priority: int):
+    def __init__(self,
+                 trainer: Trainer,
+                 priority: int,
+                 ):
         super().__init__(trainer, priority)
-        self._is_stage_to_log = is_no_pp_or_last_stage()
+        self._is_stage_to_compute = is_no_pp_or_last_stage()
         self._check_metric_states_initialization()
 
     def _check_metric_states_initialization(self):
@@ -41,33 +44,34 @@ class LossHook(MetricHook):
     :type priority: int, optional
     """
 
-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)
 
-        if self._is_stage_to_log:
-            self.metric = Loss(epoch_only=False)
+        if self._is_stage_to_compute:
+            self.train_loss = Loss(epoch_only=False)
+            self.test_loss = Loss(epoch_only=True)
 
             # register the metric calculator
             self.trainer.states['metrics']['train'][
-                self.metric.__class__.__name__] = self.metric
+                self.train_loss.__class__.__name__] = self.train_loss
             self.trainer.states['metrics']['test'][
-                self.metric.__class__.__name__] = self.metric
+                self.test_loss.__class__.__name__] = self.test_loss
 
     def before_train_epoch(self):
-        if self._is_stage_to_log:
-            self.metric.reset()
+        if self._is_stage_to_compute:
+            self.train_loss.reset()
 
     def after_train_iter(self, logits, label, loss):
-        if self._is_stage_to_log:
-            self.metric.update(loss)
+        if self._is_stage_to_compute:
+            self.train_loss.update(loss)
 
     def before_test_epoch(self):
-        if self._is_stage_to_log:
-            self.metric.reset()
+        if self._is_stage_to_compute:
+            self.test_loss.reset()
 
     def after_test_iter(self, logits, label, loss):
-        if self._is_stage_to_log:
-            self.metric.update(loss)
+        if self._is_stage_to_compute:
+            self.test_loss.update(loss)
 
 
 @HOOKS.register_module
@@ -81,10 +85,10 @@ class Accuracy2DHook(MetricHook):
     :type priority: int, optional
     """
 
-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)
 
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy2D(epoch_only=True)
 
             # register the metric
@@ -92,20 +96,20 @@ class Accuracy2DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric
 
     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()
 
     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)
 
 
 @HOOKS.register_module
 class Accuracy2p5DHook(MetricHook):
-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)
 
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy2p5D(epoch_only=True)
 
             # register the metric
@@ -113,11 +117,11 @@ class Accuracy2p5DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric
 
     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()
 
     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)
 
@@ -138,7 +142,7 @@ class Accuracy3DHook(MetricHook):
                  priority: int = 10):
         super().__init__(trainer, priority)
 
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy3D(epoch_only=True,
                                      input_parallel_mode=input_parallel_mode,
                                      weight_parallel_mode=weight_parallel_mode)
@@ -148,11 +152,11 @@ class Accuracy3DHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric
 
     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()
 
     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)
 
@@ -166,10 +170,10 @@ class AccuracyHook(MetricHook):
     :type priority: int
     """
 
-    def __init__(self, trainer: Trainer, priority: int = 10):
+    def __init__(self, trainer: Trainer, priority: int = 0):
         super().__init__(trainer, priority)
 
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric = Accuracy(epoch_only=True)
 
             # register the metric
@@ -177,9 +181,9 @@ class AccuracyHook(MetricHook):
                 self.metric.__class__.__name__] = self.metric
 
     def before_test(self):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.reset()
 
     def after_test_iter(self, logits, label, *args):
-        if self._is_stage_to_log:
+        if self._is_stage_to_compute:
             self.metric.update(logits, label)
@@ -126,6 +126,33 @@ class Loss(Metric):
         return a < b
 
 
+class LearningRate(Metric):
+    """A metric collector for learning rate.
+
+    :param epoch_only: Whether the metric only read for the full epoch
+    :type epoch_only: bool
+    """
+
+    def __init__(self, epoch_only: bool, initial_lr: float = 0.):
+        super().__init__(epoch_only=epoch_only)
+        self.lr = 0.
+
+    def reset(self) -> None:
+        pass
+
+    def update(self, lr) -> None:
+        self.lr = lr
+
+    def get_last_step_value(self):
+        return self.lr
+
+    def get_accumulated_value(self):
+        return self.lr
+
+    def is_better(a, b) -> bool:
+        pass
+
+
 class Accuracy(Metric):
     """A metric collector for accuracy. It only works for classification
     tasks.
...
@@ -5,9 +5,9 @@ from typing import Tuple
 
 import torch
 
-from .context import Config
-from .context.parallel_mode import ParallelMode
-from .core import global_context as gpc
+from colossalai.context import Config
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
 
 __all__ = [
     'get_checkpoint_path',
...
@@ -27,7 +27,7 @@ def sync_model_param_in_dp(model):
 
     :param model: A pyTorch nn.model on whose parameters you check the consistency
     '''
-    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 2:
+    if gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
         for param in model.parameters():
             ranks = gpc.get_ranks_in_group(ParallelMode.DATA)
             dist.broadcast(param, src=ranks[0], group=gpc.get_group(ParallelMode.DATA))
...
@@ -4,6 +4,7 @@ import os
 
 IMG_SIZE = 224
 BATCH_SIZE = 256
+NUM_EPOCHS = 100
 
 model = dict(
     type='VanillaResNet',
@@ -67,8 +68,6 @@ loss = dict(
     type='CrossEntropyLoss'
 )
 
-max_epochs = 100
-
 from colossalai.engine import AMP_TYPE
 
 fp16 = dict(
...
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-NUM_EPOCH = int
-
 model = dict()
 train_data = dict()
 test_data = dict()
 optimizer = dict()
 loss = dict()
-lr_scheduler = dict()
 fp16 = dict()
 zero = dict()
 gradient_handler = []
 parallel = dict()
+hooks = []
+
+num_epochs = int
+num_steps = int
 
 cudnn_benchmark = True
 cudnn_deterministic = False
...
@@ -8,10 +8,11 @@ BATCH_SIZE = 512
 IMG_SIZE = 32
 PATCH_SIZE = 4
 DIM = 512
-NUM_ATTENTION_HEADS = 8
+NUM_ATTENTION_HEADS = 2
 SUMMA_DIM = 2
 NUM_CLASSES = 10
-DEPTH = 6
+DEPTH = 1
+NUM_EPOCHS = 60
 
 train_data = dict(
     dataset=dict(
@@ -127,14 +128,22 @@ hooks = [
     dict(type='LogMetricByEpochHook'),
     dict(type='Accuracy2DHook'),
     dict(type='LossHook'),
-    dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
+    dict(type='TensorboardHook', log_dir='./tb_logs'),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]
 
 parallel = dict(
     pipeline=dict(size=1),
-    tensor=dict(size=4, mode='2d'),
+    tensor=dict(size=1, mode='2d'),
 )
 
 # for fp16 training
@@ -144,17 +153,11 @@ parallel = dict(
 #     initial_scale=2 ** 8
 # )
 
-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
-
 # only needed when pipeline parallel is used
 # schedule = dict(
 #     num_microbatches=8
 # )
 
-num_epochs = 60
-
 logging = dict(
     root_path='./logs'
...
@@ -14,6 +14,7 @@ except:
 
 BATCH_SIZE = 512
 IMG_SIZE = 32
+NUM_EPOCHS = 60
 
 train_data = dict(
     dataset=dict(
@@ -83,6 +84,14 @@ hooks = [
     ),
     dict(type='LossHook'),
     dict(type='TensorboardHook', log_dir='./tfb_logs'),
+    dict(
+        type='LRSchedulerHook',
+        by_epoch=True,
+        lr_scheduler_cfg=dict(
+            type='LinearWarmupLR',
+            warmup_steps=5
+        )
+    ),
     # dict(type='SaveCheckpointHook', interval=5, checkpoint_dir='./ckpt'),
     # dict(type='LoadCheckpointHook', epoch=20, checkpoint_dir='./ckpt')
 ]
@@ -97,13 +106,6 @@ fp16 = dict(
     initial_scale=2 ** 8
 )
 
-lr_scheduler = dict(
-    type='LinearWarmupLR',
-    warmup_epochs=5
-)
-
-num_epochs = 60
-
 logging = dict(
     root_path='./logs'
 )
colossalai.engine.amp.amp\_type
===============================
.. automodule:: colossalai.engine.amp.amp_type
:members:
colossalai.engine.amp.grad\_scaler
==================================
.. automodule:: colossalai.engine.amp.grad_scaler
:members: