Unverified Commit 0f8c7f98 authored by HELSON, committed by GitHub

Fixed docstring in colossalai (#171)

parent e2089c5c
@@ -25,11 +25,13 @@ class Trainer:
called `Trainer`.
:param engine: Engine responsible for the process function
:type engine: :class:`Engine`
:param schedule: Schedule responsible for forward and backward steps
:type schedule: :class:`BaseSchedule`, optional
:param timer: Timer used to monitor the whole training
:type timer: :class:`MultiTimer`, optional
:param logger: Logger used to record the whole training
:type logger: :class:`colossalai.logging.DistributedLogger`, optional
"""
def __init__(self,
@@ -121,6 +123,8 @@ class Trainer:
:type action: str
:param item: Name of the timer
:type item: str
:param args: args used for action function
:param kwargs: kwargs used for action function
"""
if self._timer is not None:
@@ -257,18 +261,18 @@ class Trainer:
:param max_steps: Maximum number of running iterations
:param test_dataloader: DataLoader in testing
:param test_interval: Interval of testing
:param hooks: A list of hooks used in training
:param display_progress: If True, the training progress will be printed
:param return_output_label: If True, the output of model and the label will be returned
:type train_dataloader: DataLoader
:type epochs: int
:type max_steps: int, optional
:type test_dataloader: DataLoader, optional
:type test_interval: int, optional
:type hooks: list, optional
:type display_progress: bool, optional
:type return_output_label: bool, optional
"""
# set epochs and steps, consider gradient accumulation
@@ -343,9 +347,12 @@ class Trainer:
"""Evaluates the model with testing data.
:param test_dataloader: DataLoader in testing
:param hooks: A list of hooks used in evaluation
:param display_progress: If True, the evaluation progress will be printed
:param return_output_label: If True, the output of model and the label will be returned
:type test_dataloader: DataLoader
:type hooks: list, optional
:type display_progress: bool, optional
:type return_output_label: bool
"""
...
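For orientation, a minimal sketch of how the Trainer documented above is typically driven. The colossalai.initialize entry point, the hook classes and the surrounding model/optimizer/criterion/dataloader objects are assumptions drawn from the public examples of this Colossal-AI version, not part of this commit:

    import colossalai
    from colossalai.logging import get_dist_logger
    from colossalai.trainer import Trainer, hooks
    from colossalai.utils import MultiTimer

    # assumes colossalai.launch*(config=...) was already called and that model,
    # optimizer, criterion and the raw dataloaders were built earlier
    engine, train_dataloader, test_dataloader, _ = colossalai.initialize(
        model=model, optimizer=optimizer, criterion=criterion,
        train_dataloader=train_dataloader, test_dataloader=test_dataloader)

    logger = get_dist_logger()
    trainer = Trainer(engine=engine, timer=MultiTimer(), logger=logger)

    trainer.fit(train_dataloader=train_dataloader,
                epochs=10,
                test_dataloader=test_dataloader,
                test_interval=1,
                hooks=[hooks.LossHook(), hooks.LogMetricByEpochHook(logger)],
                display_progress=True)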
@@ -12,7 +12,6 @@ class BaseHook(ABC):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
"""
def __init__(self, priority: int) -> None:
@@ -41,6 +40,8 @@ class BaseHook(ABC):
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: torch.Tensor
:param label: Labels of the input data
@@ -88,6 +89,8 @@ class BaseHook(ABC):
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: Tensor
:param label: Labels of the input data
@@ -100,6 +103,8 @@ class BaseHook(ABC):
def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param key: Key of the state to reset
:param val: Value of the state to reset
""" """
......
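As an illustration of the hook interface documented above, a small custom hook might look like the following. The import paths and the @HOOKS registration are assumptions based on how the built-in hooks in this commit are written:

    from torch import Tensor
    from colossalai.registry import HOOKS
    from colossalai.trainer.hooks import BaseHook

    @HOOKS.register_module
    class PrintLossHook(BaseHook):
        """Toy hook that prints the loss after every training iteration."""

        def __init__(self, priority: int = 10) -> None:
            super().__init__(priority)

        def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
            # `trainer` is the Trainer using this hook, as the fixed docstring notes
            print(f'train loss: {loss.item():.4f}')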
@@ -24,7 +24,6 @@ class SaveCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
"""
def __init__(self,
@@ -84,7 +83,6 @@ class LoadCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
"""
def __init__(self,
...
@@ -25,15 +25,15 @@ def _format_number(val, prec=5):
class LogByEpochHook(BaseHook):
"""Hook to log by epoch
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
"""
def __init__(self,
logger,
interval: int = 1,
@@ -48,12 +48,12 @@ class LogByEpochHook(BaseHook):
@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
"""Hook to log metric by step
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
@@ -62,7 +62,7 @@ class LogMetricByStepHook(BaseHook):
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
@@ -72,15 +72,13 @@ class LogMetricByStepHook(BaseHook):
@HOOKS.register_module
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized hook to record the metric to log.
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
"""
def __init__(self,
@@ -116,19 +114,16 @@ class LogMetricByEpochHook(LogByEpochHook):
@HOOKS.register_module
class TensorboardHook(BaseHook):
"""Specialized hook to record the metric to Tensorboard.
:param log_dir: Directory of log
:type log_dir: str
:param ranks: Ranks of processors
:type ranks: typing.List
:param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
"""
def __init__(self,
@@ -203,12 +198,12 @@ class TensorboardHook(BaseHook):
@HOOKS.register_module
class LogTimingByEpochHook(LogByEpochHook):
"""Specialized hook to write timing record to log.
:param timer: Timer for the hook
:type timer: :class:`colossalai.utils.MultiTimer`
:param logger: Logger for the log
:type logger: :class:`colossalai.logging.DistributedLogger`
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
@@ -217,9 +212,8 @@ class LogTimingByEpochHook(LogByEpochHook):
:type log_eval: bool, optional
:param ignore_num_train_steps: Number of training steps to ignore, defaults to 0
:type ignore_num_train_steps: int, optional
"""
def __init__(self,
timer: MultiTimer,
logger: DistributedLogger,
@@ -285,12 +279,13 @@ class LogMemoryByEpochHook(LogByEpochHook):
:param log_eval: Whether writes in evaluation, defaults to True
:type log_eval: bool, optional
"""
def __init__(self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False,  # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
...
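To make the logging-hook parameters above concrete, a hedged sketch of wiring a few of them together follows. The keyword names mirror the docstrings in this file, while the import path and the surrounding logger/timer objects are assumptions:

    from colossalai.logging import get_dist_logger
    from colossalai.trainer import hooks
    from colossalai.utils import MultiTimer

    logger = get_dist_logger()
    timer = MultiTimer()

    train_hooks = [
        hooks.LogMetricByEpochHook(logger, interval=1),
        hooks.LogTimingByEpochHook(timer, logger, interval=1, log_eval=False),
        hooks.TensorboardHook(log_dir='./tb_logs', ranks=[0]),
    ]
    # `train_hooks` would then be passed to Trainer.fit(..., hooks=train_hooks)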
@@ -15,7 +15,6 @@ class LRSchedulerHook(MetricHook):
:type store_lr_in_state: bool, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
"""
def __init__(
self,
...
@@ -124,6 +124,7 @@ class LossMetric(Metric):
"""
return self.last_step_loss
@staticmethod
def is_better(a, b):
return a < b
@@ -133,7 +134,7 @@ class LearningRateMetric(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
:param initial_lr: Initial learning rate, defaults to 0.0
:type initial_lr: float, optional
"""
@@ -153,6 +154,7 @@ class LearningRateMetric(Metric):
def get_accumulated_value(self):
return self.lr
@staticmethod
def is_better(a, b) -> bool:
pass
@@ -163,8 +165,8 @@ class AccuracyMetric(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
:param accuracy_func: Accuracy function for the classification task
:type accuracy_func: :class:`typing.Callable`
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
@@ -186,8 +188,8 @@ class AccuracyMetric(Metric):
and labels. It expects the output has logits and labels.
:param logits: The logits output of the model
:param targets: Real labels of the dataset
:param batch_size: Batch size of the task
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
@@ -211,6 +213,7 @@ class AccuracyMetric(Metric):
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
return (self.accumulated_correct / self.accumulated_sum).item()
@staticmethod
def is_better(a, b) -> bool:
return a > b
@@ -223,8 +226,6 @@ class MetricHook(BaseHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
"""
def __init__(
@@ -245,8 +246,6 @@ class LossHook(MetricHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
"""
def __init__(self, priority: int = 0):
@@ -288,8 +287,6 @@ class AccuracyHook(MetricHook):
:type accuracy_func: typing.Callable
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
"""
def __init__(self, accuracy_func: Callable, priority: int = 0):
@@ -319,8 +316,6 @@ class ThroughputMetric(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
@@ -353,6 +348,7 @@ class ThroughputMetric(Metric):
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
@staticmethod
def is_better(a, b) -> bool:
pass
@@ -363,8 +359,6 @@ class ThroughputHook(MetricHook):
:param priority: Priority of the throughput hook, defaults to 10
:type priority: int, optional
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
...
@@ -108,10 +108,10 @@ class CheckpointFunction(torch.autograd.Function):
def checkpoint(function, *args):
"""Checkpoint the computation while preserving the RNG states, modified from PyTorch torch.utils.checkpoint
:param function: The forward pass function. It should know how to handle the input tuples.
:param args: Tuple containing the parameters of the function
:return: Output of running function with provided args
"""
return CheckpointFunction.apply(function, *args)
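A short sketch of the activation-checkpointing helper above. The colossalai.utils import path is an assumption, and the toy forward function is purely illustrative:

    import torch
    from colossalai.utils import checkpoint  # assumed re-export of the helper above

    def mlp_block(x, weight):
        # the forward function receives exactly the tensors passed to checkpoint()
        return torch.relu(x @ weight)

    x = torch.randn(4, 16, requires_grad=True)
    w = torch.randn(16, 16, requires_grad=True)

    y = checkpoint(mlp_block, x, w)   # activations are recomputed during backward
    y.sum().backward()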
@@ -19,9 +19,8 @@ __all__ = [
def unwrap_config(config: Config):
"""Unwrap Config objects to normal dicts
"""
config_dict = dict()
for k, v in config.items():
if isinstance(v, dict):
@@ -53,18 +52,18 @@ def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
"""This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
This is useful during generation and recovery of the checkpoint.
:param checkpoint_dir: Directory for saving checkpoints
:type checkpoint_dir: str
:param epoch: Epoch number (indicates how many epochs the model has been trained)
:type epoch: int
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: Checkpoint path to be generated
:rtype: str
"""
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
return os.path.join(checkpoint_dir, ckpt_filename)
@@ -77,30 +76,30 @@ def _ensure_directory_exists(filename: str):
def get_latest_checkpoint_pattern(suffix: str = ''):
"""Generate the regular expression for the latest checkpoint's filename pattern
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: Checkpoint pattern
:rtype: regular expression
"""
ranks_name = _get_ranks_name()
ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
"""This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
This is useful during recovery of the checkpoint, especially when you do not know the epoch number.
:param checkpoint_dir: Directory for saving checkpoints
:type checkpoint_dir: str
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:raises FileNotFoundError: Raised when the latest checkpoint file cannot be found with the given inputs
:return: The latest checkpoint path to be retrieved
:rtype: str
"""
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
last_epoch = -1
@@ -128,22 +127,22 @@ def save_checkpoint(checkpoint_path: str,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler etc., into a checkpoint dictionary.
This method can be used for both colossalai nn.BaseModel and normal PyTorch nn.Module.
:param checkpoint_path: Path to which the checkpoint will be saved
:type checkpoint_path: str
:param epoch: Epoch number (indicates how many epochs the model has been trained)
:type epoch: int
:param model: Model to be saved
:type model: torch.nn.Module
:param optimizer: Optimizer to be saved
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to be saved, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
"""
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
@@ -170,31 +169,31 @@ def load_checkpoint(checkpoint_path: str,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False,
strict: bool = True) -> Tuple:
"""Loads the checkpoint file.
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules (model, optimizer, lr_scheduler) and their descendants.
If finetune is True, then only the weights and buffers of the model are reloaded.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module's state_dict() function.
:param checkpoint_path: The checkpoint path from which the appropriate state_dict is retrieved
:type checkpoint_path: str
:param model: Model to reload parameters and buffers
:type model: torch.nn.Module
:param optimizer: Optimizer to restore
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to restore, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
:param finetune: Whether to finetune the model with new dataset or continue the pre-training, defaults to False
:type finetune: bool, optional
:param strict: Whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model, defaults to True
:type strict: bool, optional
:raises ValueError: Raised if the model or optimizer cannot be successfully restored
:return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved)
:rtype: Tuple
"""
# Load the checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
try:
...
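For reference, a hedged sketch of how these checkpointing helpers fit together. The colossalai.utils.checkpointing import path and the toy model are assumptions, and a launched distributed context is required because the filename encodes the parallel ranks:

    import torch
    import torch.nn as nn
    from colossalai.utils.checkpointing import (get_checkpoint_path,
                                                save_checkpoint, load_checkpoint)

    model = nn.Linear(8, 2)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    ckpt_path = get_checkpoint_path('./checkpoints', epoch=3)
    save_checkpoint(ckpt_path, 3, model, optimizer)

    # later: resume training; finetune=False also restores the optimizer state
    last_epoch, _ = load_checkpoint(ckpt_path, model, optimizer,
                                    finetune=False, strict=True)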
@@ -6,6 +6,8 @@ import socket
import torch
from torch._six import inf
import colossalai.context.parallel_mode
try:
import colossal_C
except:
@@ -23,11 +25,13 @@ from .multi_tensor_apply import multi_tensor_applier
def print_rank_0(msg: str, logger=None):
"""Print messages and save logs (optional). This is executed only on the rank-0 GPU.
:param msg: A string message to output
:type msg: str
:param logger: Python logger object, defaults to None
:type logger: optional
"""
if gpc.get_global_rank() == 0:
if logger is None:
print(msg, flush=True)
@@ -48,10 +52,13 @@ def free_port():
def sync_model_param(model, parallel_mode):
"""Make sure data parameters are consistent during Data Parallel Mode
:param model: A PyTorch model on whose parameters you check the consistency
:param parallel_mode: Parallel mode to be checked
:type model: torch.nn.Module
:type parallel_mode: colossalai.context.ParallelMode
"""
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(parallel_mode)
@@ -124,18 +131,17 @@ def _calc_lp(grads, norm_type):
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.
This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
:param parameters: An iterable of Tensors or a single Tensor that will have gradients normalized
:type parameters: (Iterable[Tensor] or Tensor)
:param max_norm: Max norm of the gradients
:type max_norm: float or int
:param norm_type: Type of the used p-norm. Can be ``'inf'`` for infinity norm.
:type norm_type: float or int
:return: Total norm of the parameters (viewed as a single vector).
...
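A small usage sketch for the two utilities above; the import paths and the surrounding model/logger objects are assumptions about the caller's setup:

    from colossalai.logging import get_dist_logger
    from colossalai.utils import clip_grad_norm_fp32, print_rank_0

    logger = get_dist_logger()

    # inside a training step, after loss.backward():
    total_norm = clip_grad_norm_fp32(model.parameters(), max_norm=1.0, norm_type=2)
    print_rank_0(f'grad norm before clipping: {total_norm:.3f}', logger=logger)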
@@ -5,10 +5,10 @@ import torch
def set_to_cuda(models):
"""Send model to gpu.
:param models: An nn.Module or a list of modules
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
@@ -21,9 +21,8 @@ def set_to_cuda(models):
def get_current_device():
"""Returns the index of a currently selected device (gpu/cpu).
"""
if torch.cuda.is_available():
return torch.cuda.current_device()
else:
@@ -31,18 +30,16 @@ def get_current_device():
def synchronize():
"""Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
def empty_cache():
"""Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
@@ -21,13 +21,15 @@ T_co = TypeVar('T_co', covariant=True)
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism
:param dataset: A Dataset instance
:type dataset: torch.utils.data.Dataset
:param shuffle: Whether to shuffle data, defaults to False
:type shuffle: bool, optional
:param seed: The random seed, defaults to 0
:type seed: int, optional
:param drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch
size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller,
defaults to False
:type drop_last: bool, optional
"""
@@ -116,19 +118,18 @@ def get_dataloader(dataset,
pin_memory=False,
num_workers=0,
**kwargs):
"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
.. note:: When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
on the 1st stage and label on the last stage
:param dataset: A :class:`torch.utils.data.Dataset` object
:param shuffle: Whether to shuffle the dataset
:param seed: Random worker seed, defaults to 1024
:param add_sampler: Add DataParallelSampler to the dataset
:param drop_last: Drop the last incomplete batch of data
:param pin_memory: Whether to pin memory address in CPU memory
:param num_workers: Number of worker threads for this dataloader
:type dataset: :class:`torch.utils.data.Dataset`
:type shuffle: bool, optional. Default is False
@@ -138,9 +139,9 @@ def get_dataloader(dataset,
:type pin_memory: bool, optional. Default is False
:type num_workers: int, optional. Default is 0
:return: An object of :class:`torch.utils.data.DataLoader`
:rtype: :class:`torch.utils.data.DataLoader`
"""
_kwargs = kwargs.copy()
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
...
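To illustrate get_dataloader, a minimal sketch follows; torchvision, the colossalai.utils export, and the initialized data-parallel context (needed for the sampler branch) are assumptions:

    from torchvision import datasets, transforms
    from colossalai.utils import get_dataloader

    train_dataset = datasets.CIFAR10(root='./data', train=True, download=True,
                                     transform=transforms.ToTensor())

    # extra keyword arguments such as batch_size are forwarded to torch's DataLoader
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      seed=1024,
                                      batch_size=64,
                                      pin_memory=True,
                                      num_workers=2,
                                      drop_last=True)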
@@ -17,11 +17,11 @@ class GradAccumOptimizer(ColossalaiOptimizer):
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param optim: Your optimizer object
:type optim: :class:`torch.optim.Optimizer`
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
:param model: Your model object to check if it is DDP for special handling of no_sync() context
:type model: :class:`torch.nn.Module`
"""
@@ -75,7 +75,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
self.optim.backward_by_grad(tensor, grad)
class GradAccumDataloader:
"""A wrapper for the dataloader to enable gradient accumulation by dropping the last incomplete steps.
For example, if a dataloader has 10 batches of data and the accumulate size is 4, the model parameters will
@@ -83,10 +83,10 @@ class GradAccumDataloader():
Thus, they will be automatically skipped by this class. If the dataloader is not a standard PyTorch dataloader
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
:param dataloader: Your dataloader object
:type dataloader: Iterable
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
@@ -127,10 +127,10 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param lr_scheduler: Your lr scheduler object
:type lr_scheduler: :class:`torch.optim.lr_scheduler._LRScheduler`
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
@@ -170,14 +170,14 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.lr_scheduler.load_state_dict(state_dict)
class GradAccumGradientHandler:
"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param grad_handler: Your gradient handler object
:type grad_handler: :class:`colossalai.engine.BaseGradientHandler`
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
...
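A hedged sketch of wrapping the training objects with these accumulation classes. The import path is an assumption, and optimizer, model, train_dataloader and lr_scheduler are assumed to exist from earlier setup:

    from colossalai.engine.gradient_accumulation import (GradAccumOptimizer,
                                                         GradAccumDataloader,
                                                         GradAccumLrSchedulerByStep)

    accumulate_size = 4
    optimizer = GradAccumOptimizer(optimizer, accumulate_size=accumulate_size, model=model)
    train_dataloader = GradAccumDataloader(train_dataloader, accumulate_size=accumulate_size)
    lr_scheduler = GradAccumLrSchedulerByStep(lr_scheduler, accumulate_size=accumulate_size)
    # optimizer.step() and lr_scheduler.step() now only take effect every 4 calls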
@@ -12,34 +12,36 @@ from colossalai.logging import get_dist_logger
def bytes_to_GB(val, decimal=2):
"""A byte-to-gigabyte converter, using binary notation by default.
:param val: X bytes to convert
:return: X' GB
"""
return round(val / (1024 * 1024 * 1024), decimal)
def bytes_to_MB(val, decimal=2):
"""A byte-to-megabyte converter, using binary notation by default.
:param val: X bytes to convert
:return: X' MB
"""
return round(val / (1024 * 1024), decimal)
def report_memory_usage(message, logger=None, report_cpu=False):
"""Calculate and print RAM usage (in GB)
:param message: A prefix message to add in the log
:type message: str
:param logger: An instance of :class:`colossalai.logging.DistributedLogger`
:type logger: :class:`colossalai.logging.DistributedLogger`, optional
:param report_cpu: Whether to report CPU memory
:type report_cpu: bool, optional
:raises EnvironmentError: Raised if no distributed environment has been initialized
"""
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
...
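A brief usage sketch for report_memory_usage; the colossalai.utils export is an assumption, and the distributed environment must already be launched or the EnvironmentError above is raised:

    from colossalai.utils import report_memory_usage

    report_memory_usage('after building the model')
    report_memory_usage('after one training step', report_cpu=True)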
@@ -5,7 +5,7 @@ class MultiTensorApply(object):
"""
Apply an operation to a list of tensors efficiently
:param chunk_size: Size of a chunk
:type chunk_size: int
"""
@@ -22,7 +22,7 @@ class MultiTensorApply(object):
MultiTensorApply.import_err = err
def check_avail(self):
if not MultiTensorApply.available:
raise RuntimeError(
"Attempted to call MultiTensorApply method, but MultiTensorApply "
"is not available, possibly because Apex was installed without "
...
@@ -6,9 +6,8 @@ from .cuda import synchronize
class Timer:
"""A timer object which helps to log the execution times, and provides different tools to assess the times.
"""
def __init__(self):
self._started = False
@@ -21,20 +20,21 @@ class Timer:
return len(self._history) != 0
def start(self):
"""Firstly synchronize cuda, reset the clock and then start the timer.
"""
self._elapsed = 0
synchronize()
self._start_time = time.time()
self._started = True
def stop(self, keep_in_history: bool = False):
"""Stop the timer and record the start-stop time interval.
:param keep_in_history: Whether to record each start-stop interval in the history, defaults to False
:type keep_in_history: bool, optional
:return: Start-stop interval
:rtype: float
"""
synchronize()
end_time = time.time()
elapsed = end_time - self._start_time
return elapsed
def get_history_mean(self):
"""Mean of all history start-stop time intervals.
:return: Mean of time intervals
:rtype: float
"""
return sum(self._history) / len(self._history)
def get_history_sum(self):
"""Add up all the start-stop time intervals.
:return: Sum of time intervals
:rtype: float
"""
return sum(self._history)
def get_elapsed_time(self):
"""Return the last start-stop time interval.
.. note:: Use it only when timer is not in progress
:return: The last time interval
:rtype: float
"""
assert not self._started, 'Timer is still in progress'
return self._elapsed
def reset(self):
"""Clear up the timer and its history
"""
self._history = []
self._started = False
self._elapsed = 0
class MultiTimer:
"""An object that contains multiple timers
:param on: Whether the timer is enabled. Default is True
:type on: bool, optional
"""
def __init__(self, on: bool = True):
self._on = on
self._timers = dict()
def start(self, name: str):
"""Start the timer with the given name
:param name: Timer's key
:type name: str
"""
if self._on:
if name not in self._timers:
self._timers[name] = Timer()
return self._timers[name].start()
def stop(self, name: str, keep_in_history: bool):
"""Stop the timer with the given name.
:param name: Timer's key
:type name: str
:param keep_in_history: Whether to record the start-stop interval in the history
:type keep_in_history: bool
"""
if self._on:
return self._timers[name].stop(keep_in_history)
else:
return None
def get_timer(self, name):
"""Get timer by its name (from multitimer)
:param name: Timer's key
:return: The timer with the given name
:rtype: Timer
"""
return self._timers[name]
def reset(self, name=None):
"""Reset timers.
:param name: If a name is designated, the named timer will be reset and others will not, defaults to None
:type name: optional
"""
if self._on:
if name is not None:
self._timers[name].reset()
...
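A short sketch of the timer API documented above; the colossalai.utils import path is an assumption:

    from colossalai.utils import MultiTimer

    timer = MultiTimer(on=True)

    for step in range(3):
        timer.start('train-step')
        # ... one training iteration would run here ...
        timer.stop('train-step', keep_in_history=True)

    step_timer = timer.get_timer('train-step')
    print(f"mean step time: {step_timer.get_history_mean():.4f}s")
    timer.reset('train-step')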
@@ -13,17 +13,17 @@ from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
def convert_to_zero(model: nn.Module,
optimizer: Optimizer,
level: int,
zero_config: dict):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer: Your optimizer object
:type optimizer: :class:`torch.optim.Optimizer`
:param level: Optimizer level, can be 2 or 3
:type level: int
:param zero_config: Configuration for zero
:type zero_config: dict
:return: (model, optimizer)
...
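An illustrative sketch of calling convert_to_zero with the signature above; the colossalai.zero import path, the empty config, and the prior distributed/CUDA setup are assumptions:

    import torch.nn as nn
    from torch.optim import Adam
    from colossalai.zero import convert_to_zero

    model = nn.Linear(128, 128).cuda()
    optimizer = Adam(model.parameters(), lr=1e-3)

    # level 2 shards optimizer states and gradients; level 3 additionally shards parameters
    model, optimizer = convert_to_zero(model, optimizer, level=2, zero_config=dict())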
@@ -95,8 +95,11 @@ class DynamicLossScaler(LossScalerBase):
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler`.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is
encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive
iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before
increasing the loss scale.
"""
def __init__(self,
...