Unverified Commit 0f8c7f98 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

Fixed docstring in colossalai (#171)

parent e2089c5c
......@@ -25,11 +25,13 @@ class Trainer:
called `Trainer`.
:param engine: Engine responsible for the process function
:param hooks_cfg: The configuration of hooks
:param verbose: If True, additional information will be printed
:type engine: Engine
:type hoooks_cfg: Config, optional
:type verbose: bool, optional
:type engine: :class:`Engine`
:param schedule: Schedule responsible for forward and backward steps
:type schedule: :class:`BaseSchedule`, optional
:param timer: Timer used to monitor the whole training
:type timer: :class:`MultiTimer`, optional
:param logger: Logger used to record the whole training
:type logger: :class:`colossalai.logging.DistributedLogger`, optional
"""
def __init__(self,
......@@ -121,6 +123,8 @@ class Trainer:
:type action: str
:param item: Name of the timer
:type item: str
:param args: args used for action function
:param kwargs: kwargs used for action function
"""
if self._timer is not None:
......@@ -257,18 +261,18 @@ class Trainer:
:param max_steps: Maximum number of running iterations
:param test_dataloader: DataLoader in testing
:param test_interval: Interval of testing
:param hooks_cfg: A list of hook configuration
:param hooks: A list of hooks used in training
:param display_progress: If True, the training progress will be printed
:param return_output_label: If True, the output of model and the label will be returned
:type return_output_label: bool
:type train_dataloader: DataLoader
:type epochs: int
:type max_steps: int
:type test_dataloader: DataLoader
:type test_interval: int
:type hooks_cfg: dict
:type display_progress: bool
:type gradient_accumulation: int
:type max_steps: int, optional
:type test_dataloader: DataLoader, optional
:type test_interval: int, optional
:type hooks: list, optional
:type display_progress: bool, optional
:type return_output_label: bool, optional
"""
# set epochs and steps, consider gradient accumulation
......@@ -343,9 +347,12 @@ class Trainer:
"""Evaluates the model with testing data.
:param test_dataloader: DataLoader in testing
:param hooks: A list of hooks used in evaluation
:param display_progress: If True, the evaluation progress will be printed
:param return_output_label: If True, the output of model and the label will be returned
:type test_dataloader: DataLoader
:type hooks: list, optional
:type display_progress: bool, optional
:type return_output_label: bool
"""
......
......@@ -12,7 +12,6 @@ class BaseHook(ABC):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
"""
def __init__(self, priority: int) -> None:
......@@ -41,6 +40,8 @@ class BaseHook(ABC):
def after_train_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a training iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: torch.Tensor
:param label: Labels of the input data
......@@ -88,6 +89,8 @@ class BaseHook(ABC):
def after_test_iter(self, trainer, output: Tensor, label: Tensor, loss: Tensor):
"""Actions after running a testing iteration.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param output: Output of the model
:type output: Tensor
:param label: Labels of the input data
......@@ -100,6 +103,8 @@ class BaseHook(ABC):
def init_runner_states(self, trainer, key, val):
"""Initializes trainer's state.
:param trainer: Trainer which is using this hook
:type trainer: :class:`Trainer`
:param key: Key of reseting state
:param val: Value of reseting state
"""
......
......@@ -24,7 +24,6 @@ class SaveCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
......@@ -84,7 +83,6 @@ class LoadCheckpointHook(BaseHook):
:type suffix: str, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
......
......@@ -25,15 +25,15 @@ def _format_number(val, prec=5):
class LogByEpochHook(BaseHook):
"""hook to log by epoch
"""Hook to log by epoch
:param logger: logger for the log
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self,
logger,
interval: int = 1,
......@@ -48,12 +48,12 @@ class LogByEpochHook(BaseHook):
@HOOKS.register_module
class LogMetricByStepHook(BaseHook):
"""hook to log metric by step
"""Hook to log metric by step
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
......@@ -62,7 +62,7 @@ class LogMetricByStepHook(BaseHook):
for metric_name, metric_calculator in trainer.states['metrics']['train'].items():
trainer.states['step_metrics'][metric_name.lower()] = \
f'{_format_number(metric_calculator.get_last_step_value())}'
def after_test_iter(self, trainer, *args):
trainer.states['step_metrics'] = dict()
for metric_name, metric_calculator in trainer.states['metrics']['test'].items():
......@@ -72,15 +72,13 @@ class LogMetricByStepHook(BaseHook):
@HOOKS.register_module
class LogMetricByEpochHook(LogByEpochHook):
"""Specialized Hook to record the metric to log.
"""Specialized hook to record the metric to log.
:param logger: logger for the log
:param logger: Logger for the log
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, 'train' and 'test'
"""
def __init__(self,
......@@ -116,19 +114,16 @@ class LogMetricByEpochHook(LogByEpochHook):
@HOOKS.register_module
class TensorboardHook(BaseHook):
"""Specialized Hook to record the metric to Tensorboard.
"""Specialized hook to record the metric to Tensorboard.
:param log_dir: Directory of log
:type log_dir: str
:param ranks: ranks of processors
:param ranks: Ranks of processors
:type ranks: typing.List
:param parallel_mode: Parallel mode, defaults to colossalai.context.parallel_mode.ParallelMode.GLOBAL
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:param mode: Mode of metrics, 'train' and 'test'
:type mode: str
"""
def __init__(self,
......@@ -203,12 +198,12 @@ class TensorboardHook(BaseHook):
@HOOKS.register_module
class LogTimingByEpochHook(LogByEpochHook):
"""Specialized Hook to write timing record to log.
"""Specialized hook to write timing record to log.
:param timer: Timer for the hook
:type timer: colossalai.utils.MultiTimer
:type timer: :class:`colossalai.utils.MultiTimer`
:param logger: Logger for the log
:type logger: colossalai.logging.DistributedLogger
:type logger: :class:`colossalai.logging.DistributedLogger`
:param interval: Recording interval, defaults to 1
:type interval: int, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 10
......@@ -217,9 +212,8 @@ class LogTimingByEpochHook(LogByEpochHook):
:type log_eval: bool, optional
:param ignore_num_train_steps: Number of training steps to ignore, defaults to 0
:type ignore_num_train_steps: int, optional
:param mode: Mode of metrics, 'train' and 'test'
:param trainer: Trainer attached with current hook
"""
def __init__(self,
timer: MultiTimer,
logger: DistributedLogger,
......@@ -285,12 +279,13 @@ class LogMemoryByEpochHook(LogByEpochHook):
:param log_eval: Whether writes in evaluation, defaults to True
:type log_eval: bool, optional
"""
def __init__(self,
logger: DistributedLogger,
interval: int = 1,
priority: int = 10,
log_eval: bool = True,
report_cpu: bool = False, # no reference
report_cpu: bool = False, # no reference
) -> None:
super().__init__(logger=logger, interval=interval, priority=priority)
self._log_eval = log_eval
......
......@@ -15,7 +15,6 @@ class LRSchedulerHook(MetricHook):
:type store_lr_in_state: bool, optional
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 1
:type priority: int, optional
:param trainer: Trainer attached with current hook
"""
def __init__(
self,
......
......@@ -124,6 +124,7 @@ class LossMetric(Metric):
"""
return self.last_step_loss
@staticmethod
def is_better(a, b):
return a < b
......@@ -133,7 +134,7 @@ class LearningRateMetric(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
:param initial_lr: initial learning rate, defaults to 0.0
:param initial_lr: Initial learning rate, defaults to 0.0
:type initial_lr: float, optional
"""
......@@ -153,6 +154,7 @@ class LearningRateMetric(Metric):
def get_accumulated_value(self):
return self.lr
@staticmethod
def is_better(a, b) -> bool:
pass
......@@ -163,8 +165,8 @@ class AccuracyMetric(Metric):
:param epoch_only: Whether the metric only read for the full epoch
:type epoch_only: bool
:param accuracy_func: accuracy function for the classification task
:type accuracy_func: typing.Callable
:param accuracy_func: Accuracy function for the classification task
:type accuracy_func: :class:`typing.Callable`
"""
def __init__(self, epoch_only: bool, accuracy_func: Callable):
......@@ -186,8 +188,8 @@ class AccuracyMetric(Metric):
and labels. It expects the output has logits and labels.
:param logits: The logits output of the model
:param targets: real labels of the dataset
:param batch_size: batch size of the task
:param targets: Real labels of the dataset
:param batch_size: Batch size of the task
"""
if isinstance(logits, (list, tuple)):
logits = logits[0]
......@@ -211,6 +213,7 @@ class AccuracyMetric(Metric):
self.accumulated_correct = all_reduce(self.accumulated_correct, ParallelMode.DATA)
return (self.accumulated_correct / self.accumulated_sum).item()
@staticmethod
def is_better(a, b) -> bool:
return a > b
......@@ -223,8 +226,6 @@ class MetricHook(BaseHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front
:type priority: int
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(
......@@ -245,8 +246,6 @@ class LossHook(MetricHook):
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, priority: int = 0):
......@@ -288,8 +287,6 @@ class AccuracyHook(MetricHook):
:type accuracy_func: typing.Callable
:param priority: Priority in the printing, hooks with small priority will be printed in front, defaults to 0
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, accuracy_func: Callable, priority: int = 0):
......@@ -319,8 +316,6 @@ class ThroughputMetric(Metric):
:param epoch_only: epoch only
:type epoch_only: bool
:param num_samples: number of samples
:param time: time
"""
def __init__(self, epoch_only: bool):
super().__init__(epoch_only=epoch_only)
......@@ -353,6 +348,7 @@ class ThroughputMetric(Metric):
self.accumulated_num_samples = all_reduce(self.accumulated_num_samples, ParallelMode.DATA)
return (self.accumulated_num_samples / (self.accumulated_used_time + 1e-12)).item()
@staticmethod
def is_better(a, b) -> bool:
pass
......@@ -363,8 +359,6 @@ class ThroughputHook(MetricHook):
:param priority: priority of throughput hook, defaults to 10
:type priority: int, optional
:param trainer: Trainer attached with current hook
:type trainer: Trainer
"""
def __init__(self, priority: int = 10):
super().__init__(priority)
......
......@@ -108,10 +108,10 @@ class CheckpointFunction(torch.autograd.Function):
def checkpoint(function, *args):
'''Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint
"""Checkpoint the computation while preserve the rng states, modified from Pytorch torch.utils.checkpoint
:param function: describe the forward pass function. It should know how to handle the input tuples.
:param args: tuple containing inputs to the function
:return: Output of running function on \*args
'''
:param function: Describe the forward pass function. It should know how to handle the input tuples.
:param args: Tuple containing the parameters of the function
:return: Output of running function with provided args
"""
return CheckpointFunction.apply(function, *args)
......@@ -19,9 +19,8 @@ __all__ = [
def unwrap_config(config: Config):
'''
unwrap Config objects to normal dicts
'''
"""Unwrap Config objects to normal dicts
"""
config_dict = dict()
for k, v in config.items():
if isinstance(v, dict):
......@@ -53,18 +52,18 @@ def _get_standard_checkpoint_filename(epoch: int, suffix: str = ''):
def get_checkpoint_path(checkpoint_dir: str, epoch: int, suffix: str = ''):
'''This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
"""This is a function to generate the checkpoint path from the (checkpoint_dir, epoch, suffix, gpu_parallel_rank) tuple.
This is useful during generation and recuperation of the checkpoint.
:param checkpoint_dir: set up a directory for saving checkpoints
:param checkpoint_dir: Set up a directory for saving checkpoints
:type checkpoint_dir: str
:param epoch: epoch number (indicate how many epochs have you trained this model)
:param epoch: Epoch number (indicate how many epochs have you trained this model)
:type epoch: int
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: checkpoint path to be generated
:return: Checkpoint path to be generated
:rtype: path
'''
"""
ckpt_filename = _get_standard_checkpoint_filename(epoch, suffix)
return os.path.join(checkpoint_dir, ckpt_filename)
......@@ -77,30 +76,30 @@ def _ensure_directory_exists(filename: str):
def get_latest_checkpoint_pattern(suffix: str = ''):
'''Generate Regular expression of latest checkpoint's pattern
"""Generate Regular expression of latest checkpoint's pattern
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:return: checkpoint pattern
:return: Checkpoint pattern
:rtype: regular expression
'''
"""
ranks_name = _get_ranks_name()
ckpt_pattern = re.compile(f'epoch(\d+)-{ranks_name}{suffix}\.pt')
return ckpt_pattern
def get_latest_checkpoint_path(checkpoint_dir: str, suffix: str = ''):
'''This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
"""This is a function to retrieve the latest checkpoint path from the (checkpoint_dir, suffix, gpu_parallel_rank) tuple.
This is useful during recuperation of the checkpoint, especially when you do not know the epoch number.
:param checkpoint_dir: directory for saving checkpoints
:param checkpoint_dir: Directory for saving checkpoints
:type checkpoint_dir: str
:param suffix: additional notation to specify the model or checkpoint, defaults to ''
:param suffix: Additional notation to specify the model or checkpoint, defaults to ''
:type suffix: str, optional
:raises FileNotFoundError: raise error when we cannot find the latest checkpoint file with inputs given
:return: the latest checkpoint path to be retrieved
:raises FileNotFoundError: Raise error when we cannot find the latest checkpoint file with inputs given
:return: The latest checkpoint path to be retrieved
:rtype: path
'''
"""
CKPT_NAME_PAT = get_latest_checkpoint_pattern(suffix=suffix)
last_epoch = -1
......@@ -128,22 +127,22 @@ def save_checkpoint(checkpoint_path: str,
optimizer: torch.optim.Optimizer,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
**kwargs):
'''Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
"""Given a directory to store the checkpoints, saves all the training components' parameters or buffers, such as model, optimizer, lr_scheduler and etc. into a checkpoint dictionary.
This method can be used for both colosalai nn.BaseModel and normal pytorch nn.Module.
:param checkpoint_path: set up a directory for saving checkpoints
:param checkpoint_path: Set up a directory for saving checkpoints
:type checkpoint_path: str
:param epoch: epoch number (indicate how many epochs have you trained this model)
:param epoch: Epoch number (indicate how many epochs have you trained this model)
:type epoch: int
:param model: model to be registered
:param model: Model to be registered
:type model: torch.nn.Module
:param optimizer: optimizer to be registered
:param optimizer: Optimizer to be registered
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to be registered, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
'''
"""
# for compatibility with normal pytorch nn.Module
if hasattr(model, 'state_dict_for_save_checkpoint'):
model_sd = model.state_dict_for_save_checkpoint()
......@@ -170,31 +169,31 @@ def load_checkpoint(checkpoint_path: str,
lr_scheduler: torch.optim.lr_scheduler._LRScheduler = None,
finetune: bool = False,
strict: bool = True) -> Tuple:
'''Loads the checkpoint file.
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
"""Loads the checkpoint file.
If finetune is False, then we intend to continue/resume the training process from the checkpoint given.
So we copy parameters and buffers from state_dict into these modules(model, optimizer,lr_scheduler) and its descendants.
If finetune is True, then only the weights and buffers of model should be reload.
If strict is True, then the keys of state_dict must exactly match the keys returned by this module’s state_dict() function.
:param checkpoint_path: the exact and matched checkpoint_path directory to retrieve appropriate state_dict
:param checkpoint_path: The exact and matched checkpoint_path directory to retrieve appropriate state_dict
:type checkpoint_path: str
:param model: model to reload parameters and buffers
:param model: Model to reload parameters and buffers
:type model: torch.nn.Module
:param optimizer: optimizer to recuperate
:type optimizer: torch.optim.Optimizer
:param optimizer: Optimizer to recuperate
:type optimizer: torch.optim.Optimizer
:param lr_scheduler: lr_scheduler to recuperate, defaults to None
:type lr_scheduler: torch.optim.lr_scheduler._LRScheduler, optional
:param finetune: whether to finetune the model with new dataset or continue the pre-training, defaults to False
:param finetune: Whether to finetune the model with new dataset or continue the pre-training, defaults to False
:type finetune: bool, optional
:param strict: whether to strictly enforce that the keys in
:param strict: Whether to strictly enforce that the keys in
:attr:`state_dict` of the checkpoint match the names of
parameters and buffers in model., defaults to True
:type strict: bool, optional
:raises ValueError: raise error if the model/optimizer cannot successfully be recuperated
:raises ValueError: Raise error if the model/optimizer cannot successfully be recuperated
:return: (the epoch number of the checkpoint retrieved, the checkpoint retrieved)
:rtype: Tuple
'''
"""
# Load the checkpoint.
checkpoint = torch.load(checkpoint_path, map_location='cpu')
try:
......
......@@ -6,6 +6,8 @@ import socket
import torch
from torch._six import inf
import colossalai.context.parallel_mode
try:
import colossal_C
except:
......@@ -23,11 +25,13 @@ from .multi_tensor_apply import multi_tensor_applier
def print_rank_0(msg: str, logger=None):
'''Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
"""Print messages and save logs(optional). This is executed only if you are the rank-0 gpu.
:param msg: A str message to output
:param logger: python logger object, defaults to None
'''
:param msg: A string message to output
:type msg: str
:param logger: Python logger object, defaults to None
:type logger: optional
"""
if gpc.get_global_rank() == 0:
if logger is None:
print(msg, flush=True)
......@@ -48,10 +52,13 @@ def free_port():
def sync_model_param(model, parallel_mode):
'''Make sure data parameters are consistent during Data Parallel Mode
"""Make sure data parameters are consistent during Data Parallel Mode
:param model: A pyTorch nn.model on whose parameters you check the consistency
'''
:param parallel_mode: Parallel mode to be checked
:type model: torch.nn.Module
:type parallel_mode: colossalai.context.ParallelMode
"""
if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
for param in model.parameters():
ranks = gpc.get_ranks_in_group(parallel_mode)
......@@ -124,18 +131,17 @@ def _calc_lp(grads, norm_type):
def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
"""Clips gradient norm of an iterable of parameters whose gradients are in fp32.
This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and
This is adapted from :func:`torch.nn.utils.clip_grad.clip_grad_norm_` and
added functionality to handle model parallel parameters. Note that
the gradients are modified in place.
:param parameters: an iterable of Tensors or a single Tensor that will have gradients normalized
:param parameters: An iterable of Tensors or a single Tensor that will have gradients normalized
:type parameters: (Iterable[Tensor] or Tensor)
:param max_norm: max norm of the gradients
:param max_norm: Max norm of the gradients
:type max_norm: float or int
:param norm_type: type of the used p-norm. Can be ``'inf'`` for infinity norm.
:param norm_type: Type of the used p-norm. Can be ``'inf'`` for infinity norm.
:type norm_type: float or int
:return: Total norm of the parameters (viewed as a single vector).
......
......@@ -5,10 +5,10 @@ import torch
def set_to_cuda(models):
'''Send model to gpu.
"""Send model to gpu.
:param models: nn.module or a list of module
'''
"""
if isinstance(models, list) and len(models) > 1:
ret = []
for model in models:
......@@ -21,9 +21,8 @@ def set_to_cuda(models):
def get_current_device():
'''
Returns the index of a currently selected device (gpu/cpu).
'''
"""Returns the index of a currently selected device (gpu/cpu).
"""
if torch.cuda.is_available():
return torch.cuda.current_device()
else:
......@@ -31,18 +30,16 @@ def get_current_device():
def synchronize():
'''
Similar to cuda.synchronize().
"""Similar to cuda.synchronize().
Waits for all kernels in all streams on a CUDA device to complete.
'''
"""
if torch.cuda.is_available():
torch.cuda.synchronize()
def empty_cache():
'''
Similar to cuda.empty_cache()
"""Similar to cuda.empty_cache()
Releases all unoccupied cached memory currently held by the caching allocator.
'''
"""
if torch.cuda.is_available():
torch.cuda.empty_cache()
......@@ -21,13 +21,15 @@ T_co = TypeVar('T_co', covariant=True)
class DataParallelSampler(Sampler):
"""A data sampler for distributed data parallelism
:param dataset: a Dataset instance
:param dataset: A Dataset instance
:type dataset: torch.utils.data.Dataset
:param shuffle: whether to shuffle data, defaults to False
:param shuffle: Whether to shuffle data, defaults to False
:type shuffle: bool, optional
:param seed: the random seed, defaults to 0
:param seed: The random seed, defaults to 0
:type seed: int, optional
:param drop_last: set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller, defaults to False
:param drop_last: Set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch
size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller,
defaults to False
:type drop_last: bool, optional
"""
......@@ -116,19 +118,18 @@ def get_dataloader(dataset,
pin_memory=False,
num_workers=0,
**kwargs):
'''Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not)
.. note: when pipeline parallel is enabled, shuffle cannot be True
as it will result in mismatch between input data on the 1st
stage and label on the last stage
.. note:: When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data
on the 1st stage and label on the last stage
:param dataset: a :class:utils.data.dataset dataset
:param shuffle: whether to shuffle the dataset
:param seed: random worker seed, defaults to 1024
:param add_sampler: add DistributedDataParallelSampelr to the dataset
:param drop_last: drop the last incomplete batch of data
:param pin_memory: whether to pin memory address in CPU memory
:param num_workers: number of worker threads for this dataloader
:param dataset: A :class:`utils.data.dataset dataset`
:param shuffle: Whether to shuffle the dataset
:param seed: Random worker seed, defaults to 1024
:param add_sampler: Add DistributedDataParallelSampelr to the dataset
:param drop_last: Drop the last incomplete batch of data
:param pin_memory: Whether to pin memory address in CPU memory
:param num_workers: Number of worker threads for this dataloader
:type dataset: :class:`torch.utils.data.Dataset`
:type shuffle: bool, optional. Default is False
......@@ -138,9 +139,9 @@ def get_dataloader(dataset,
:type pin_memory: bool, optional. Default is False
:type num_workers: int, optional. Default is 0
:return: a object of :class:`torch.utils.data.DataLoader`
:return: A object of :class:`torch.utils.data.DataLoader`
:rtype: :class:`torch.utils.data.DataLoader`
'''
"""
_kwargs = kwargs.copy()
if add_sampler and gpc.is_initialized(ParallelMode.DATA) and gpc.get_world_size(ParallelMode.DATA) > 1:
......
......@@ -17,11 +17,11 @@ class GradAccumOptimizer(ColossalaiOptimizer):
"""A wrapper for the optimizer to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param optim: your optimizer object
:param optim: Your optimizer object
:type optim: :class:`torch.optim.Optimizer`
:param accumulate_size: the number of steps to accumulate gradients
:type accumualate_size: int
:param model: your model object to check if it is DDP for special handling of no_sync() context
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
:param model: Your model object to check if it is DDP for special handling of no_sync() context
:type model: :class:`torch.nn.Module`
"""
......@@ -75,7 +75,7 @@ class GradAccumOptimizer(ColossalaiOptimizer):
self.optim.backward_by_grad(tensor, grad)
class GradAccumDataloader():
class GradAccumDataloader:
"""A wrapper for dataloder to enable gradient accumulation by dropping the last incomplete steps.
For example, if a dataloader has 10 batches of data and accumulate size is 4. The model paramters will
......@@ -83,10 +83,10 @@ class GradAccumDataloader():
Thus, they will be automatically skipped by this class. If the dataloader is not standard PyTorch dataloader,
(e.g. Dali dataloader), this class will automatically consume (load data for nothing) the remaining 2 batches.
:param dataloader: your dataloader object
:param dataloader: Your dataloader object
:type dataloader: Iterable
:param accumulate_size: the number of steps to accumulate gradients
:type accumualate_size: int
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
......@@ -127,10 +127,10 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
"""A wrapper for the LR scheduler to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param lr_scheduler: your lr scheduler object
:param lr_scheduler: Your lr scheduler object
:type lr_scheduler: :class:`torch.optim.lr_scheduler._LRScheduler`
:param accumulate_size: the number of steps to accumulate gradients
:type accumualate_size: int
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
......@@ -170,14 +170,14 @@ class GradAccumLrSchedulerByStep(_LRScheduler):
self.lr_scheduler.load_state_dict(state_dict)
class GradAccumGradientHandler():
class GradAccumGradientHandler:
"""A wrapper for the gradient handler to enable gradient accumulation by skipping the steps
before accumulation size is reached
:param grad_handler: your gradient handler object
:param grad_handler: Your gradient handler object
:type grad_handler: :class:`colossalai.engine.BaseGradientHandler`
:param accumulate_size: the number of steps to accumulate gradients
:type accumualate_size: int
:param accumulate_size: The number of steps to accumulate gradients
:type accumulate_size: int
"""
......
......@@ -12,34 +12,36 @@ from colossalai.logging import get_dist_logger
def bytes_to_GB(val, decimal=2):
'''A byte-to-Gigabyte converter, defaultly using binary notation.
"""A byte-to-Gigabyte converter, defaultly using binary notation.
:param val: X bytes to convert
:return: X' GB
'''
"""
return round(val / (1024 * 1024 * 1024), decimal)
def bytes_to_MB(val, decimal=2):
'''A byte-to-Megabyte converter, defaultly using binary notation.
"""A byte-to-Megabyte converter, defaultly using binary notation.
:param val: X bytes to convert
:return: X' MB
'''
"""
return round(val / (1024 * 1024), decimal)
def report_memory_usage(message, logger=None, report_cpu=False):
'''Calculate and print RAM usage (in GB)
"""Calculate and print RAM usage (in GB)
:param message: a prefix message to add in the log
:param message: A prefix message to add in the log
:type message: str
:param logger: an instance of :class:`colossalai.logging.DistributedLogger`
:type logger: :class:`colossalai.logging.DistributedLogger`
:param report_cpu: whether to report CPU memory
:type report_cpu: bool
:raises EnvironmentError: raise error if no distributed environment has been initialized
'''
:param logger: An instance of :class:`colossalai.logging.DistributedLogger`
:type logger: :class:`colossalai.logging.DistributedLogger`, optional
:param report_cpu: Whether to report CPU memory
:type report_cpu: bool, optional
:raises EnvironmentError: Raise error if no distributed environment has been initialized
"""
if not gpc.is_initialized(ParallelMode.GLOBAL):
raise EnvironmentError("No distributed environment is initialized")
gpu_allocated = bytes_to_MB(torch.cuda.memory_allocated())
gpu_max_allocated = bytes_to_MB(torch.cuda.max_memory_allocated())
......
......@@ -5,7 +5,7 @@ class MultiTensorApply(object):
"""
Apply an operation to a list of tensors efficiently
:param chunk_size: size of a chunk
:param chunk_size: Size of a chunk
:type chunk_size: int
"""
......@@ -22,7 +22,7 @@ class MultiTensorApply(object):
MultiTensorApply.import_err = err
def check_avail(self):
if MultiTensorApply.available == False:
if not MultiTensorApply.available:
raise RuntimeError(
"Attempted to call MultiTensorApply method, but MultiTensorApply "
"is not available, possibly because Apex was installed without "
......
......@@ -6,9 +6,8 @@ from .cuda import synchronize
class Timer:
'''
A timer object which helps to log the execution times, and provides different tools to assess the times.
'''
"""A timer object which helps to log the execution times, and provides different tools to assess the times.
"""
def __init__(self):
self._started = False
......@@ -21,20 +20,21 @@ class Timer:
return len(self._history) != 0
def start(self):
'''Fisrtly synchronize cuda, reset the clock and then start the timer.
'''
"""Fisrtly synchronize cuda, reset the clock and then start the timer.
"""
self._elapsed = 0
synchronize()
self._start_time = time.time()
self._started = True
def stop(self, keep_in_history: bool = False):
'''Stop the timer and record the start-stop time interval.
:param keep_in_history: whether does it record into history each start-stop interval, defaults to False
"""Stop the timer and record the start-stop time interval.
:param keep_in_history: Whether does it record into history each start-stop interval, defaults to False
:type keep_in_history: bool, optional
:return: start-stop interval
:return: Start-stop interval
:rtype: int
'''
"""
synchronize()
end_time = time.time()
elapsed = end_time - self._start_time
......@@ -45,79 +45,90 @@ class Timer:
return elapsed
def get_history_mean(self):
'''mean of all history start-stop time intervals.
:return: mean of time intervals
"""Mean of all history start-stop time intervals.
:return: Mean of time intervals
:rtype: int
'''
"""
return sum(self._history) / len(self._history)
def get_history_sum(self):
'''add up all the start-stop time intervals.
:return: sum of time intervals
"""Add up all the start-stop time intervals.
:return: Sum of time intervals
:rtype: int
'''
"""
return sum(self._history)
def get_elapsed_time(self):
'''return the last start-stop time interval. *use it only when timer is not in progress*
:return: the last time interval
"""Return the last start-stop time interval.
.. note:: Use it only when timer is not in progress
:return: The last time interval
:rtype: int
'''
"""
assert not self._started, 'Timer is still in progress'
return self._elapsed
def reset(self):
'''clear up the timer and its history
'''
"""Clear up the timer and its history
"""
self._history = []
self._started = False
self._elapsed = 0
class MultiTimer:
'''An object contains multiple timers
"""An object contains multiple timers
:param on: whether the timer is enabled. Default is True
:type on: bool
'''
:param on: Whether the timer is enabled. Default is True
:type on: bool, optional
"""
def __init__(self, on: bool = True):
self._on = on
self._timers = dict()
def start(self, name: str):
'''Start namely one of the timers
:param name: timer's key
"""Start namely one of the timers
:param name: Timer's key
:type name: str
'''
"""
if self._on:
if name not in self._timers:
self._timers[name] = Timer()
return self._timers[name].start()
def stop(self, name: str, keep_in_history: bool):
'''Stop namely one of the timers.
:param name: timer's key
:param keep_in_history: whether does it record into history each start-stop interval
"""Stop namely one of the timers.
:param name: Timer's key
:type name: str
:param keep_in_history: Whether does it record into history each start-stop interval
:type keep_in_history: bool
'''
"""
if self._on:
return self._timers[name].stop(keep_in_history)
else:
return None
def get_timer(self, name):
'''Get timer by its name (from multitimer)
:param name: timer's key
:return: timer with the name you give correctly
:rtype: Timer
'''
"""Get timer by its name (from multitimer)
:param name: Timer's key
:return: Timer with the name you give correctly
:rtype: Timer
"""
return self._timers[name]
def reset(self, name=None):
'''Reset timers.
:param name: if name is designated, the named timer will be reset and others will not, defaults to None
'''
"""Reset timers.
:param name: If name is designated, the named timer will be reset and others will not, defaults to None
:type name: optional
"""
if self._on:
if name is not None:
self._timers[name].reset()
......
......@@ -13,17 +13,17 @@ from .zero_redundancy_optimizer_level_3 import ZeroRedundancyOptimizer_Level_3
def convert_to_zero(model: nn.Module,
optimizer: Optimizer,
level: int,
zero_config):
zero_config: dict):
"""
A helper function to integrate the model and optimizer with ZeRO optimizer and off-loading
:param model: your model object
:param model: Your model object
:type model: :class:`torch.nn.Module`
:param optimizer: your optimizer object
:param optimizer: Your optimizer object
:type optimizer: :class:`torch.optim.Optimizer`
:param level: optimizer level, can be 2 or 3
:param level: Optimizer level, can be 2 or 3
:type level: int
:param zero_config: configuration for zero
:param zero_config: Configuration for zero
:type zero_config: dict
:return: (model, optimizer)
......
......@@ -95,8 +95,11 @@ class DynamicLossScaler(LossScalerBase):
always using the highest loss scale possible without incurring overflow.
Args:
init_scale (float, optional, default=2**32): Initial loss scale attempted by :class:`DynamicLossScaler.`
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before increasing the loss scale.
scale_factor (float, optional, default=2.0): Factor used when adjusting the loss scale. If an overflow is
encountered, the loss scale is readjusted to loss scale/``scale_factor``. If ``scale_window`` consecutive
iterations take place without an overflow, the loss scale is readjusted to loss_scale*``scale_factor``.
scale_window (int, optional, default=1000): Number of consecutive iterations without an overflow to wait before
increasing the loss scale.
"""
def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment