Commit ec5086c4 authored by Liang Bowen, committed by アマデウス

Refactored docstring to google style

parent 53b1b6e3
@@ -12,21 +12,27 @@ from .naive_amp import convert_to_naive_amp
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with Torch AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.amp.AMP_TYPE`): AMP mode.
        amp_config (:class:`colossalai.context.Config` or dict): configuration for the chosen AMP mode.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary depending on the mode you choose. You should check the corresponding AMP mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'
    ...
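For context, a minimal usage sketch of the function above (the model and config values are placeholders, the import paths are assumed from the repository layout, and a distributed environment is assumed to have been launched beforehand, e.g. via colossalai.launch):

    import torch
    import torch.nn as nn
    from colossalai.amp import AMP_TYPE, convert_to_amp  # assumed import path

    model = nn.Linear(16, 4).cuda()                       # placeholder model
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # amp_config is forwarded to the chosen AMP backend; for AMP_TYPE.TORCH it
    # configures torch.cuda.amp.GradScaler (see the torch_amp link above).
    model, optimizer, criterion = convert_to_amp(model, optimizer, criterion,
                                                 mode=AMP_TYPE.TORCH,
                                                 amp_config=dict(init_scale=2.**16))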
@@ -4,17 +4,33 @@ from torch.optim import Optimizer
def convert_to_apex_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    r"""A helper function to wrap training components with Apex AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (:class:`colossalai.context.Config` or dict): configuration for initializing apex_amp.

    The ``amp_config`` should include parameters below:
    ::

        enabled (bool, optional, default=True)
        opt_level (str, optional, default="O1")
        cast_model_type (``torch.dtype``, optional, default=None)
        patch_torch_functions (bool, optional, default=None)
        keep_batchnorm_fp32 (bool or str, optional, default=None)
        master_weights (bool, optional, default=None)
        loss_scale (float or str, optional, default=None)
        cast_model_outputs (torch.dtype, optional, default=None)
        num_losses (int, optional, default=1)
        verbosity (int, default=1)
        min_loss_scale (float, default=None)
        max_loss_scale (float, default=2.**24)

    Returns:
        Tuple: A tuple (model, optimizer).

    More details about ``amp_config`` can be found in `amp_config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
    """
    import apex.amp as apex_amp
    model, optimizer = apex_amp.initialize(model, optimizer, **amp_config)
    ...
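As a concrete illustration of the ``amp_config`` shape described above (the values shown are ordinary Apex choices rather than project defaults beyond those listed):

    # The dict is unpacked into apex.amp.initialize(model, optimizer, **amp_config).
    apex_amp_config = dict(
        enabled=True,
        opt_level='O1',        # patch torch functions for mixed precision
        loss_scale=None,       # None -> dynamic loss scaling
        verbosity=1,
    )
    # model, optimizer = convert_to_apex_amp(model, optimizer, apex_amp_config)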
@@ -21,8 +21,8 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
    def backward(self, loss: Tensor):
        """Backward pass to get all gradients.

        Args:
            loss (torch.Tensor): Loss computed by a loss function.
        """
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()
@@ -30,10 +30,9 @@ class ApexAMPOptimizer(ColossalaiOptimizer):
    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip gradients' norm.

        Args:
            model (torch.nn.Module): Your model object.
            max_norm (float): The max norm value for gradient clipping.
        """
        if max_norm > 0:
            clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
@@ -4,20 +4,30 @@ from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
from .grad_scaler import DynamicGradScaler, ConstantGradScaler
from ._fp16_optimizer import FP16Optimizer


def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    The ``amp_config`` should contain parameters below:
    ::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default: 0).
            Note that clipping is ignored if clip_grad_norm == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.

    Returns:
        Tuple: A tuple (model, optimizer).
    """
    if isinstance(model, nn.ModuleList):
        # interleaved pipeline
        ...
@@ -46,4 +56,4 @@ def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    return model, optimizer


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
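An illustrative ``amp_config`` for the naive mode, using only the keys documented above:

    naive_amp_config = dict(
        verbose=False,
        clip_grad_norm=1.0,        # 0 disables gradient clipping
        dynamic_grad_scale=True,   # DynamicGradScaler instead of ConstantGradScaler
    )
    # model, optimizer = convert_to_naive_amp(model, optimizer, naive_amp_config)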
@@ -42,24 +42,13 @@ def _multi_tensor_copy_this_to_that(this, that, overflow_buf=None):
class FP16Optimizer(Optimizer):
    """Float16 optimizer for fp16 and bf16 data types.

    Args:
        optimizer (torch.optim.Optimizer): base optimizer such as Adam or SGD.
        grad_scaler (BaseGradScaler): grad scaler for gradients, chosen from
            ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
            Note that clipping is ignored if clip_grad_norm == 0.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.
    """

    def __init__(self,
        ...
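A sketch of how the refactored constructor is wired together. The ``DynamicGradScaler`` keyword names below are taken from the old docstring removed in this commit (initial_scale, growth_factor, backoff_factor, hysteresis, max_scale) and should be treated as assumptions rather than the exact signature:

    # Assumed grad-scaler keyword names; check grad_scaler.py for the real signature.
    grad_scaler = DynamicGradScaler(initial_scale=2.**16,
                                    growth_factor=2,
                                    backoff_factor=0.5,
                                    hysteresis=2,
                                    max_scale=2.**32)
    fp16_optim = FP16Optimizer(optimizer,
                               grad_scaler=grad_scaler,
                               clip_grad_norm=1.0,
                               verbose=False)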
@@ -18,11 +18,15 @@ from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer):
    """A wrapper class for an optimizer to cast all parameters to fp16.

    Args:
        optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD.
        grad_scaler (BaseGradScaler): grad scaler for gradients, chosen from
            ``constant_grad_scaler`` or ``dynamic_grad_scaler``.
        clip_grad_norm (float, optional): clip gradients with this global L2 norm. Default 0.
        verbose (bool, optional): if set to `True`, will print debug info. Default False.

    Note:
        Clipping is ignored if ``clip_grad_norm`` equals 0.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
@@ -40,8 +44,19 @@ class NaiveAMPOptimizer(ColossalaiOptimizer):
class NaiveAMPModel(nn.Module):
    r"""A wrapper class for a model to cast it into fp16 and
    automatically cast its input and output.

    Args:
        model (torch.nn.Module): torch.nn.Module to be wrapped.
        output_to_fp32 (bool, optional): Whether to cast the output of this module to fp32. (Default: True)
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module.
            (Default: ``ParallelMode.DATA``)
        sync_buffer (bool, optional): whether to synchronize buffers. (Default: True)

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
    """

    def __init__(self,
        ...
@@ -10,18 +10,25 @@ def convert_to_torch_amp(model: nn.Module,
                         optimizer: Optimizer,
                         criterion: Optional[_Loss] = None,
                         amp_config: Optional[Config] = None):
    """A helper function to wrap training components with PyTorch AMP modules.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`, optional): your loss function object.
        amp_config (:class:`colossalai.context.Config` or dict, optional): configuration for PyTorch AMP.

    The ``amp_config`` should include parameters below:
    ::

        init_scale (float, optional, default=2.**16)
        growth_factor (float, optional, default=2.0)
        backoff_factor (float, optional, default=0.5)
        growth_interval (int, optional, default=2000)
        enabled (bool, optional, default=True)

    Returns:
        A tuple (model, optimizer, criterion).
    """
    model = TorchAMPModel(model)
    if amp_config is None:
        ...
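The corresponding ``amp_config`` for PyTorch AMP simply mirrors the GradScaler arguments listed above:

    torch_amp_config = dict(
        init_scale=2.**16,
        growth_factor=2.0,
        backoff_factor=0.5,
        growth_interval=2000,
        enabled=True,
    )
    # model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, torch_amp_config)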
@@ -14,13 +14,19 @@ from colossalai.utils import clip_grad_norm_fp32
class TorchAMPOptimizer(ColossalaiOptimizer):
    """A wrapper class which integrates PyTorch AMP with an optimizer.

    Args:
        optim (torch.optim.Optimizer): A normal optimizer like Adam or SGD.
        init_scale (float, optional, default=2.**16): Initial scale factor.
        growth_factor (float, optional, default=2.0): Factor by which the scale is multiplied during
            :meth:`update` if no inf/NaN gradients occur for ``growth_interval`` consecutive iterations.
        backoff_factor (float, optional, default=0.5): Factor by which the scale is multiplied during
            :meth:`update` if inf/NaN gradients occur in an iteration.
        growth_interval (int, optional, default=2000): Number of consecutive iterations without inf/NaN gradients
            that must occur for the scale to be multiplied by ``growth_factor``.
        enabled (bool, optional, default=True): If ``False``, disables gradient scaling. :meth:`step` simply
            invokes the underlying ``optimizer.step()``, and other methods become no-ops.
    """

    def __init__(self, optim: Optimizer, *args, **kwargs):
@@ -30,8 +36,8 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
    def backward(self, loss: Tensor):
        """Backward with torch amp gradient scaler.

        Args:
            loss (torch.Tensor): Loss computed by a loss function.
        """
        self.scaler.scale(loss).backward()
@@ -44,10 +50,9 @@ class TorchAMPOptimizer(ColossalaiOptimizer):
    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Apply gradient clipping to the model parameters.

        Args:
            model (torch.nn.Module): Your model object.
            max_norm (float): Max norm value for gradient clipping.
        """
        if max_norm > 0.0:
            self.scaler.unscale_(self.optim)
@@ -71,8 +76,8 @@ class TorchAMPModel(nn.Module):
class TorchAMPLoss(nn.Module):
    """A wrapper class for a criterion object which computes the loss in mixed-precision context.

    Args:
        loss (torch.nn.modules.loss._Loss): A loss function object.
    """

    def __init__(self, loss: _Loss):
        ...
@@ -10,34 +10,40 @@ from colossalai.registry import *
def build_from_config(module, config: dict):
    """Returns an object of :class:`module` constructed from `config`.

    Args:
        module: A python or user-defined class.
        config: A python dict containing information used in the construction of the return object.

    Returns:
        An ``object`` of interest.

    Raises:
        AssertionError: Raises an AssertionError if `module` is not a class.
    """
    assert inspect.isclass(module), 'module must be a class'
    return module(**config)
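A minimal sketch of what ``build_from_config`` does with an arbitrary class (the module path is assumed; the helper simply unpacks the dict into the constructor):

    import torch.nn as nn
    from colossalai.builder.builder import build_from_config  # assumed module path

    # Equivalent to nn.Linear(in_features=16, out_features=4)
    layer = build_from_config(nn.Linear, dict(in_features=16, out_features=4))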
def build_from_registry(config, registry: Registry):
    r"""Returns an object constructed from `config`; the type of the object
    is specified by `registry`.

    Note:
        The `config` is used to construct the return object, such as `LAYERS`,
        `OPTIMIZERS` and other supported types in `registry`. The `config` should contain
        all required parameters of the corresponding object. The details of the supported
        types in `registry` and the `mod_type` in `config` can be found in
        `registry <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/registry/__init__.py>`_.

    Args:
        config (dict or :class:`colossalai.context.Config`): information
            used in the construction of the return object.
        registry (:class:`Registry`): A registry specifying the type of the return object.

    Returns:
        A Python object specified by `registry`.

    Raises:
        Exception: Raises an Exception if an error occurred when building from registry.
    """
    config_ = config.copy()  # keep the original config untouched
    assert isinstance(
        ...
@@ -60,11 +66,13 @@ def build_from_registry(config, registry: Registry):
def build_layer(config):
    """Returns a layer object of :class:`nn.Module` constructed from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``LAYERS`` object.

    Returns:
        An object of :class:`torch.nn.Module`.
    """
    return build_from_registry(config, LAYERS)
@@ -73,11 +81,13 @@ def build_loss(config):
    """Returns a loss function object of :class:`torch.autograd.Function` constructed
    from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``LOSSES`` object.

    Returns:
        An object of :class:`torch.nn.modules.loss._Loss`.
    """
    return build_from_registry(config, LOSSES)
@@ -85,11 +95,13 @@ def build_loss(config):
def build_model(config):
    """Returns a model object of :class:`nn.Module` constructed from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``MODELS`` object.

    Returns:
        An object of :class:`torch.nn.Module`.
    """
    return build_from_registry(config, MODELS)
@@ -98,11 +110,13 @@ def build_dataset(config):
    """Returns a dataset object of :class:`torch.utils.data.Dataset` constructed
    from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``DATASETS`` object.

    Returns:
        An object of :class:`torch.utils.data.Dataset`.
    """
    return build_from_registry(config, DATASETS)
@@ -111,13 +125,14 @@ def build_optimizer(config, model):
    """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
    'model' and 'params'.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``OPTIMIZERS`` object.
        model (:class:`nn.Module`): A model containing parameters for the optimizer.

    Returns:
        An object of :class:`torch.optim.Optimizer`.
    """
    config_ = config.copy()
    config_['params'] = model.parameters()
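For the registry-based builders, the config carries a ``type`` field naming a registered class and the remaining keys become constructor arguments. A hypothetical sketch, assuming 'Adam' (or another optimizer class name) is registered in the ``OPTIMIZERS`` registry:

    optimizer_cfg = dict(type='Adam', lr=1e-3, weight_decay=0)
    # 'params' is filled in from model.parameters() by build_optimizer itself.
    # optimizer = build_optimizer(optimizer_cfg, model)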
@@ -128,15 +143,15 @@ def build_gradient_handler(config, model, optimizer):
    """Returns a gradient handler object of :class:`BaseGradientHandler` constructed from `config`,
    `model` and `optimizer`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``GRADIENT_HANDLER`` object.
        model (:class:`nn.Module`): A model containing parameters for the gradient handler.
        optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing parameters for the gradient handler.

    Returns:
        An object of :class:`colossalai.engine.BaseGradientHandler`.
    """
    config_ = config.copy()
    config_['model'] = model
@@ -147,13 +162,13 @@ def build_gradient_handler(config, model, optimizer):
def build_hooks(config, trainer):
    """Returns a hook object of :class:`BaseHook` constructed from `config` and `trainer`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``HOOKS`` object.
        trainer (:class:`Trainer`): A :class:`Trainer` object containing parameters for the hook.

    Returns:
        An object of :class:`colossalai.trainer.hooks.BaseHook`.
    """
    config_ = config.copy()
    config_['trainer'] = trainer
@@ -163,11 +178,13 @@ def build_hooks(config, trainer):
def build_ophooks(config):
    """Returns a hook object of :class:`BaseOpHook` constructed from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``OPHOOKS`` object.

    Returns:
        An object of :class:`colossalai.trainer.hooks.BaseOpHook`.
    """
    config_ = config.copy()
    return build_from_registry(config_, OPHOOKS)
@@ -177,11 +194,13 @@ def build_transform(config):
    """Returns a transformation object of :class:`torchvision.transforms` constructed
    from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``TRANSFORMS`` object.

    Returns:
        An object of :class:`torchvision.transforms`.
    """
    return build_from_registry(config, TRANSFORMS)
@@ -190,14 +209,15 @@ def build_data_sampler(config, dataset):
    """Returns a data sampler object of :class:`colossalai.nn.data.sampler.BaseSampler`
    constructed from `config`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``DATA_SAMPLERS`` object.
        dataset (:class:`torch.utils.data.Dataset`): An object of
            :class:`torch.utils.data.Dataset` containing information
            used in the construction of the return object.

    Returns:
        An object of :class:`colossalai.utils.data_sampler.BaseSampler`.
    """
    config_ = config.copy()
    config_['dataset'] = dataset
@@ -208,14 +228,15 @@ def build_lr_scheduler(config, optimizer):
    """Returns a learning rate scheduler object of :class:`torch.optim.lr_scheduler`
    constructed from `config`, `optimizer`, `total_steps` and `num_steps_per_epoch`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``lr_schedule`` object.
        optimizer (:class:`torch.optim.Optimizer`): An optimizer object containing
            parameters for the learning rate scheduler.

    Returns:
        An object of :class:`torch.optim.lr_scheduler`.
    """
    config_ = config.copy()
    config_['optimizer'] = optimizer
@@ -225,10 +246,12 @@ def build_lr_scheduler(config, optimizer):
def build_schedule(config):
    """Returns a schedule of :class:`colossalai.engine.schedule.BaseSchedule`.

    Args:
        config (dict or :class:`colossalai.context.Config`): A python dict or
            a :class:`colossalai.context.Config` object containing information
            used in the construction of the ``Schedule`` object.

    Returns:
        An object of :class:`colossalai.engine.schedule.BaseSchedule`.
    """
    return build_from_registry(config, SCHEDULE)
@@ -13,14 +13,13 @@ def _binary_partition(weights, st, ed):
    """Returns the binary partition position of `weights`, given the start
    position `st` and the end position `ed`.

    Args:
        weights (list): A python list to be binary partitioned.
        st (int): the start position of the binary partition.
        ed (int): the end position of the binary partition.

    Returns:
        int: the binary partition position of `weights`.
    """
    w_sum = weights[ed - 1]
    prefix = 0
    ...
@@ -176,16 +175,13 @@ def build_pipeline_model_from_cfg(config, num_chunks: int = 1, partition_method:
        ...
    )

    Args:
        config (dict): Configuration of the model.
        num_chunks (int, optional): The number of chunks you want to have on the current stage.
            This value should be 1 in most cases unless you are using virtual pipeline parallelism.
        partition_method (str, optional): This parameter determines how you want to split your model
            layers into stages, you can set it as 'layer' or 'parameter'.
        verbose (bool, optional): Whether to print the logs.
    """
    ori_model = build_model(config)
    layers = ori_model.layers_cfg
@@ -240,13 +236,11 @@ def build_pipeline_model(layers: nn.Sequential, num_chunks: int = 1, verbose: bo
    """An initializer to split the model into different stages for pipeline parallelism.
    Note that `layers` must be a `torch.nn.Sequential`.

    Args:
        layers (`torch.nn.Sequential`): Layers of the model.
        num_chunks (int, optional): The number of chunks you want to have on the current stage.
            This value should be 1 in most cases unless you are using virtual pipeline parallelism.
        verbose (bool, optional): Whether to print the logs.
    """
    pipeline_parallel_size = gpc.get_world_size(ParallelMode.PIPELINE)
    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    ...
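A small sketch of splitting a plain ``nn.Sequential`` across pipeline stages with the helper above (import path assumed; requires a launched pipeline-parallel environment):

    import torch.nn as nn
    from colossalai.builder import build_pipeline_model  # assumed import path

    layers = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 10))
    # Each pipeline rank keeps only its own slice of `layers`; num_chunks > 1 is
    # only needed for interleaved (virtual) pipeline parallelism.
    stage = build_pipeline_model(layers, num_chunks=1, verbose=True)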
@@ -12,21 +12,22 @@ from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
    r"""Gathers all tensors from the parallel group and concatenates them in a
    specific dimension.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be gathered.
        dim (int): The dimension to concatenate along.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
        async_op (bool, optional): Whether the operation is asynchronous.

    Returns:
        Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of the all-gather only,
        if async_op is set to False. A tuple of the all-gather output and an async work handle, if async_op is set to True.
    """
    depth = gpc.get_world_size(parallel_mode)
    if depth == 1:
        ...
@@ -54,23 +55,26 @@ def reduce_scatter(tensor: Tensor,
                   parallel_mode: ParallelMode,
                   op: ReduceOp = ReduceOp.SUM,
                   async_op: bool = False) -> Tensor:
    r"""Reduces all tensors and then scatters the result in a specific dimension to all
    members in the parallel group.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be reduce-scattered.
        dim (int): The dimension to scatter along.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
        op (torch.distributed.ReduceOp, optional): The type of reduce operation,
            should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
            More details about ReduceOp can be found in
            `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
        async_op (bool, optional): Whether the operation is asynchronous.

    Returns:
        Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of the reduce-scatter only,
        if async_op is set to False. A tuple of the reduce-scatter output and an async work handle, if async_op is set to True.
    """
    depth = gpc.get_world_size(parallel_mode)
    if depth == 1:
        ...
@@ -94,6 +98,25 @@ def all_reduce(tensor: Tensor,
               parallel_mode: ParallelMode,
               op: ReduceOp = ReduceOp.SUM,
               async_op: bool = False) -> Tensor:
    r"""Reduces the tensor data across the whole parallel group in such a way that all processes get the final result.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be all-reduced.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
        op (torch.distributed.ReduceOp, optional): The type of reduce operation,
            should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
            More details about ReduceOp can be found in
            `ReduceOp <https://pytorch.org/docs/stable/distributed.html#torch.distributed.ReduceOp>`_.
        async_op (bool, optional): Whether the operation is asynchronous.

    Returns:
        Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of the all-reduce only,
        if async_op is set to False. A tuple of the all-reduce output and an async work handle, if async_op is set to True.
    """
    depth = gpc.get_world_size(parallel_mode)
    if depth == 1:
        out = tensor
@@ -108,6 +131,23 @@ def all_reduce(tensor: Tensor,
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
    r"""Broadcasts a tensor to the whole parallel group. The tensor must have the same
    number of elements in all processes participating in the collective.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be broadcast.
        src (int): Source rank.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
        async_op (bool, optional): Whether the operation is asynchronous.

    Returns:
        Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The broadcast tensor only,
        if async_op is set to False. A tuple of the broadcast tensor and an async work handle, if async_op is set to True.
    """
    depth = gpc.get_world_size(parallel_mode)
    if depth == 1:
        out = tensor
@@ -122,6 +162,23 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
    r"""Reduces tensors across the whole parallel group. Only the process with
    rank ``dst`` receives the final result.

    Note:
        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.

    Args:
        tensor (:class:`torch.Tensor`): Tensor to be reduced.
        dst (int): Destination rank.
        parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication.
        async_op (bool, optional): Whether the operation is asynchronous.

    Returns:
        Union[tuple(:class:`torch.Tensor`, work handle), :class:`torch.Tensor`]: The result of the reduce only,
        if async_op is set to False. A tuple of the reduce output and an async work handle, if async_op is set to True.
    """
    depth = gpc.get_world_size(parallel_mode)
    if depth == 1:
        out = tensor
    ...
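A usage sketch for the collectives documented above, assuming a launched ColossalAI distributed context and the ``colossalai.communication`` import path:

    import torch
    from torch.distributed import ReduceOp
    from colossalai.communication import all_gather, all_reduce, broadcast  # assumed import path
    from colossalai.context import ParallelMode

    x = torch.randn(4, 8).cuda()
    gathered = all_gather(x, dim=0, parallel_mode=ParallelMode.DATA)          # shape (4 * world_size, 8)
    summed = all_reduce(x, parallel_mode=ParallelMode.DATA, op=ReduceOp.SUM)  # same result on every rank
    shared = broadcast(x, src=0, parallel_mode=ParallelMode.DATA)             # rank 0's tensor on every rank
    # With async_op=True each call would instead return (tensor, work_handle).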
@@ -19,12 +19,12 @@ TensorShape = Union[torch.Size, List[int], Tuple[int]]
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
    """Gets the exact tensor shape used for communication and returns whether the tensor is a chunk.

    Args:
        tensor_shape (:class:`torch.Size`): shape of the tensor.
        chunk_tensor (bool, optional): whether to chunk the tensor, defaults to False.

    Returns:
        Tuple[Union[torch.Size, List[int], Tuple[int]], bool]: exact tensor shape, whether to chunk the tensor.
    """
    if chunk_tensor:
        tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
@@ -134,14 +134,14 @@ def _communicate(tensor_send_next=None,
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
    """Copies the forward output from the previous stage in the pipeline as the input tensor of this stage.

    Args:
        input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
        prev_rank (int, optional): The rank of the source of the tensor.

    Returns:
        :class:`torch.Tensor`: The input tensor.
    """
    if gpc.is_pipeline_first_stage():
        input_tensor = None
@@ -155,14 +155,14 @@ def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
    """Copies the gradient tensor from the next stage in the pipeline as the input gradient of this stage.

    Args:
        output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
        next_rank (int, optional): The rank of the source of the tensor.

    Returns:
        :class:`torch.Tensor`: The input gradient tensor.
    """
    if gpc.is_pipeline_last_stage():
        output_tensor_grad = None
@@ -176,12 +176,11 @@ def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
    """Sends the input tensor to the next stage in the pipeline.

    Args:
        output_tensor (:class:`torch.Tensor`): Tensor to be sent.
        next_rank (int, optional): The rank of the recipient of the tensor.
    """
    if not gpc.is_pipeline_last_stage():
        _communicate(tensor_send_next=output_tensor,
@@ -190,12 +189,11 @@ def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
    """Sends the gradient tensor to the previous stage in the pipeline.

    Args:
        input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
        prev_rank (int, optional): The rank of the recipient of the tensor.
    """
    if not gpc.is_pipeline_first_stage():
        _communicate(tensor_send_prev=input_tensor_grad,
...@@ -210,15 +208,15 @@ def send_forward_recv_backward(output_tensor, ...@@ -210,15 +208,15 @@ def send_forward_recv_backward(output_tensor,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the next stage in pipeline, while receives the gradient tensor from the
next member in pipeline. next stage in pipeline as the input gradient tensor of this stage.
:param output_tensor: Tensor to be sent Args:
:param output_grad_shape: The shape of the tensor to be recieved output_tensor (:class:`torch.Tensor`): Tensor to be sent.
:type output_tensor: :class:`torch.Tensor` output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
:type output_grad_shape: :class:`torch.Size`
:return: The grad of output tensor in forward step Returns:
:rtype: :class:`torch.Tensor` :class:`torch.Tensor`: The input gradient tensor.
""" """
if gpc.is_pipeline_last_stage(): if gpc.is_pipeline_last_stage():
output_tensor_grad = None output_tensor_grad = None
...@@ -238,16 +236,16 @@ def send_backward_recv_forward(input_tensor_grad, ...@@ -238,16 +236,16 @@ def send_backward_recv_forward(input_tensor_grad,
prev_rank=None, prev_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the """Batched communication operation. Sends the gradient tensor to the
previous member in pipeline, while recieves the input tensor from the previous stage in pipeline, while receives the output tensor from the
previous member in pipeline. previous stage in pipeline as the input of this stage.
:param input_tensor_grad: Tensor to be sent Args:
:param input_tensor_shape: The shape of the tensor to be recieved input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
:type input_tensor_grad: :class:`torch.Tensor` input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
:type input_tensor_shape: :class:`torch.Size`
:return: The input tensor in forward step Returns:
:rtype: :class:`torch.Tensor` :class:`torch.Tensor`: The input tensor.
""" """
if gpc.is_pipeline_first_stage(): if gpc.is_pipeline_first_stage():
input_tensor = None input_tensor = None
...@@ -269,15 +267,15 @@ def send_forward_recv_forward(output_tensor, ...@@ -269,15 +267,15 @@ def send_forward_recv_forward(output_tensor,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the next stage in pipeline, while receives the output tensor from the
previous member in pipeline. previous stage in pipeline as the input of this stage.
:param output_tensor: Tensor to be sent Args:
:param input_tensor_shape: The shape of the tensor to be recieved output_tensor (:class:`torch.Tensor`): Tensor to be sent.
:type output_tensor: :class:`torch.Tensor` input_tensor_shape (:class:`torch.Size`): The shape of the tensor to be received.
:type input_tensor_shape: :class:`torch.Size`
:return: The input tensor in forward step Returns:
:rtype: :class:`torch.Tensor` :class:`torch.Tensor`: The input tensor.
""" """
input_tensor, _ = _communicate(tensor_send_next=output_tensor, input_tensor, _ = _communicate(tensor_send_next=output_tensor,
recv_prev=recv_prev, recv_prev=recv_prev,
...@@ -296,16 +294,16 @@ def send_backward_recv_backward(input_tensor_grad, ...@@ -296,16 +294,16 @@ def send_backward_recv_backward(input_tensor_grad,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the """Batched communication operation. Sends the gradient tensor to the
previous member in pipeline, while recieves the grad tensor from the previous stage in pipeline, while receives the gradient tensor from the
next member in pipeline. next member in pipeline as the input of this stage.
:param input_tensor_grad: Tensor to be sent Args:
:param output_grad_shape: The shape of the tensor to be recieved input_tensor_grad (:class:`torch.Tensor`): Tensor to be sent.
:type input_tensor_grad: :class:`torch.Tensor` output_grad_shape (:class:`torch.Size`): The shape of the tensor to be received.
:type output_grad_shape: :class:`torch.Size`
:return: The grad of output tensor in forward step Returns:
:rtype: :class:`torch.Tensor` :class:`torch.Tensor`: The input gradient tensor.
""" """
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad, _, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
recv_next=recv_next, recv_next=recv_next,
...@@ -327,20 +325,18 @@ def send_forward_backward_recv_forward_backward(output_tensor, ...@@ -327,20 +325,18 @@ def send_forward_backward_recv_forward_backward(output_tensor,
next_rank=None, next_rank=None,
dtype=torch.float, dtype=torch.float,
scatter_gather_tensors=False): scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the next and """Batched communication operation. Sends the input tensor to the next stage in pipeline and
the grad tensor to the previous, while recieves the grad tensor from the the gradient tensor to the previous stage, while receives the input gradient tensor from the
next and the input tensor from the previous. next stage and the input tensor from the previous stage.
:param output_tensor: Tensor sent to the next Args:
:param input_tensor_grad: Tensor sent to the previous output_tensor (:class:`torch.Tensor`): Tensor sent to the next.
:param input_tensor_shape: The shape of the tensor recieved from the previous input_tensor_grad (:class:`torch.Tensor`): Tensor sent to the previous.
:param output_grad_shape: The shape of the tensor recieved from the next input_tensor_shape (:class:`torch.Size`): The shape of the tensor received from the previous.
:type output_tensor: :class:`torch.Tensor` output_grad_shape (:class:`torch.Size`): The shape of the tensor received from the next.
:type input_tensor_grad: :class:`torch.Tensor`
:type input_tensor_shape: :class:`torch.Size` Returns:
:type output_grad_shape: :class:`torch.Size` Tuple(Tensor, Tensor): (the input tensor, the input gradient tensor)
:return: (the input tensor in forward step, the grad of output tensor in forward step)
:rtype: (Tensor, Tensor)
""" """
input_tensor, output_tensor_grad = _communicate( input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor, tensor_send_next=output_tensor,
......
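A minimal usage sketch for this batched exchange, assuming the function is re-exported from ``colossalai.communication`` as in the upstream package and that the keyword names match the documented parameters (``input_tensor_shape``, ``output_grad_shape``); shapes and device placement are illustrative only::

    import torch
    from colossalai.communication import send_forward_backward_recv_forward_backward

    # Activations produced by this stage and the gradient w.r.t. this stage's input,
    # shipped out together in one batched p2p call (shapes are made up).
    output_tensor = torch.randn(8, 512, device='cuda')
    input_tensor_grad = torch.randn(8, 256, device='cuda')

    input_tensor, output_tensor_grad = send_forward_backward_recv_forward_backward(
        output_tensor,
        input_tensor_grad,
        input_tensor_shape=torch.Size((8, 256)),
        output_grad_shape=torch.Size((8, 512)),
        dtype=torch.float)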
...@@ -9,15 +9,19 @@ from colossalai.utils import get_current_device, synchronize ...@@ -9,15 +9,19 @@ from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode): def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
"""Sends a tensor to the next member and recieves a tensor from the previous member. """Sends a tensor to the next member and receives a tensor from the previous member.
This function returns the recieved tensor from the previous member. This function returns the received tensor from the previous member.
:param tensor_send_next: Tensor sent to next member Args:
:param parallel_mode: Parallel group mode used in this communication tensor_send_next: Tensor sent to next member
:type tensor_send_next: :class:`torch.Tensor` parallel_mode: Parallel group mode used in this communication
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: The tensor recieved from the previous Returns:
:rtype: :class:`torch.Tensor` :class:`torch.Tensor`: The tensor received from the previous.
Note:
The parallel_mode should be one of the modes defined in ``ParallelMode``. More details about ``ParallelMode`` could be found
in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
""" """
buffer_shape = tensor_send_next.size() buffer_shape = tensor_send_next.size()
......
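A short sketch of a ring exchange built on this helper, assuming the module path matches the file shown in this diff and that the chosen parallel group (``ParallelMode.SEQUENCE`` here, purely for illustration) has already been initialized::

    import torch
    from colossalai.communication.ring import ring_forward
    from colossalai.context import ParallelMode

    # Each rank contributes a local chunk; after the call it holds the chunk
    # sent by the previous member of the ring.
    local_chunk = torch.randn(4, 128, device='cuda')
    recv_chunk = ring_forward(local_chunk, ParallelMode.SEQUENCE)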
...@@ -12,14 +12,13 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None): ...@@ -12,14 +12,13 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
meta information of the tensor should be sent before communications. This function meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`. synchronizes with :func:`recv_tensor_meta`.
:param tensor: Tensor to be sent Args:
:param need_meta: If False, meta information won't be sent tensor (torch.Tensor): Tensor to be sent.
:param next_rank: The rank of the next member in pipeline parallel group need_meta (bool, optional): If False, meta information won't be sent.
:type tensor: Tensor next_rank (int): The rank of the next member in pipeline parallel group.
:type need_meta: bool, optional
:type next_rank: int Returns:
:return: False bool: False
:rtype: bool
""" """
if need_meta: if need_meta:
if next_rank is None: if next_rank is None:
...@@ -36,17 +35,17 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None): ...@@ -36,17 +35,17 @@ def send_tensor_meta(tensor, need_meta=True, next_rank=None):
def recv_tensor_meta(tensor_shape, prev_rank=None): def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor. """Receives tensor meta information before receiving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications, Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function meta information of the tensor should be received before communications. This function
synchronizes with :func:`send_tensor_meta`. synchronizes with :func:`send_tensor_meta`.
:param tensor_shape: The shape of the tensor to be recieved Args:
:param prev_rank: The rank of the source of the tensor tensor_shape (torch.Size): The shape of the tensor to be received.
:type tensor_shape: torch.Size prev_rank (int): The rank of the source of the tensor.
:type prev_rank: int, optional
:return: The shape of the tensor to be recieved Returns:
:rtype: torch.Size torch.Size: The shape of the tensor to be received.
""" """
if tensor_shape is None: if tensor_shape is None:
if prev_rank is None: if prev_rank is None:
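A sketch of how the two meta helpers pair up across adjacent pipeline stages; the orchestration below is hypothetical and assumes both functions are importable from ``colossalai.communication``::

    import torch
    from colossalai.communication import send_tensor_meta, recv_tensor_meta
    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    output = torch.randn(8, 1024, device='cuda')

    if not gpc.is_last_rank(ParallelMode.PIPELINE):
        # Producer side: announce the shape before the tensor itself is sent.
        send_tensor_meta(output)

    if not gpc.is_first_rank(ParallelMode.PIPELINE):
        # Consumer side: passing None asks the previous stage for the shape.
        input_shape = recv_tensor_meta(None)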
...@@ -67,14 +66,12 @@ def recv_tensor_meta(tensor_shape, prev_rank=None): ...@@ -67,14 +66,12 @@ def recv_tensor_meta(tensor_shape, prev_rank=None):
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks. """Break a tensor into equal 1D chunks.
:param tensor: Tensor to be splitted before communication Args:
:param new_buffer: Whether uses a new buffer to store sliced tensor tensor (torch.Tensor): Tensor to be split before communication.
new_buffer (bool, optional): Whether to use a new buffer to store sliced tensor.
:type tensor: torch.Tensor Returns:
:type new_buffer: bool, optional torch.Tensor: The split tensor
:return splitted_tensor: The splitted tensor
:rtype splitted_tensor: torch.Tensor
""" """
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D) partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D) start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
...@@ -92,11 +89,10 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False): ...@@ -92,11 +89,10 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def gather_split_1d_tensor(tensor): def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks. """Opposite of above function, gather values from model parallel ranks.
:param tensor: Tensor to be gathered after communication Args:
:type tensor: torch.Tensor tensor (torch.Tensor): Tensor to be gathered after communication.
Returns:
:return gathered: The gathered tensor torch.Tensor: The gathered tensor.
:rtype gathered: torch.Tensor
""" """
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D) world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor) numel = torch.numel(tensor)
......
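A round-trip sketch under 1D tensor parallelism, assuming these utilities live in ``colossalai.communication.utils`` as in the upstream repository and that the 1D parallel group is already set up; the tensor size is made up and must divide evenly across the 1D world size::

    import torch
    from colossalai.communication.utils import (split_tensor_into_1d_equal_chunks,
                                                gather_split_1d_tensor)

    full = torch.randn(4096, device='cuda')

    # Keep only this rank's contiguous 1D slice.
    local_slice = split_tensor_into_1d_equal_chunks(full, new_buffer=True)

    # All-gather the slices back into a flat tensor with the original number of elements.
    restored = gather_split_1d_tensor(local_slice)
    assert restored.numel() == full.numel()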
...@@ -12,8 +12,8 @@ class Config(dict): ...@@ -12,8 +12,8 @@ class Config(dict):
"""This is a wrapper class for dict objects so that values of which can be """This is a wrapper class for dict objects so that values of which can be
accessed as attributes. accessed as attributes.
:param config: The dict object to be wrapped Args:
:type config: dict config (dict): The dict object to be wrapped.
""" """
def __init__(self, config: dict = None): def __init__(self, config: dict = None):
...@@ -50,12 +50,14 @@ class Config(dict): ...@@ -50,12 +50,14 @@ class Config(dict):
def from_file(filename: str): def from_file(filename: str):
"""Reads a python file and constructs a corresponding :class:`Config` object. """Reads a python file and constructs a corresponding :class:`Config` object.
:param filename: Name of the file to construct the return object Args:
:type filename: str filename (str): Name of the file to construct the return object.
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file
is not .py file Returns:
:return: A :class:`Config` object constructed with information in the file :class:`Config`: A :class:`Config` object constructed with information in the file.
:rtype: :class:`Config`
Raises:
AssertionError: Raises an AssertionError if the file does not exist, or the file is not a .py file.
""" """
# check config path # check config path
......
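A tiny sketch of the intended workflow: a plain ``.py`` file defines configuration variables, and ``Config.from_file`` turns them into a dict whose values are also reachable as attributes (the file contents below are made up)::

    # hypothetical ./config.py
    #   parallel = dict(pipeline=2, tensor=dict(size=4, mode='2d'))

    from colossalai.context import Config

    cfg = Config.from_file('./config.py')
    print(cfg.parallel)        # accessed as an attribute rather than cfg['parallel']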
...@@ -22,6 +22,10 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -22,6 +22,10 @@ class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context, """This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device. such as the global rank, the local rank, the world size, etc. of each device.
Note:
The parallel_mode used in this class should be one of the modes defined in ``ParallelMode``.
More details about ``ParallelMode`` could be found in
`parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
""" """
def __init__(self): def __init__(self):
...@@ -62,10 +66,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -62,10 +66,12 @@ class ParallelContext(metaclass=SingletonMeta):
def load_config(self, config: Union[dict, str]): def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file. """Loads the configuration from either a dict or a file.
:param config: Either a dict containing the configuration information or the filename Args:
of a file containing the configuration information config (dict or str): Either a dict containing the configuration information or the filename
:type config: dict or str of a file containing the configuration information.
:raises TypeError: Raises a TypeError if `config` is neither a dict or a str
Raises:
TypeError: Raises a TypeError if `config` is neither a dict nor a str.
""" """
if isinstance(config, str): if isinstance(config, str):
self._config = Config.from_file(config) self._config = Config.from_file(config)
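Both accepted forms side by side, as a sketch; ``gpc`` is the global ``ParallelContext`` singleton exposed as ``colossalai.core.global_context``, and the configuration values are placeholders::

    from colossalai.core import global_context as gpc

    # From a dict ...
    gpc.load_config({'parallel': dict(pipeline=2, tensor=dict(size=4, mode='2d'))})

    # ... or from a python config file; any other type raises TypeError.
    gpc.load_config('./config.py')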
...@@ -81,20 +87,21 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -81,20 +87,21 @@ class ParallelContext(metaclass=SingletonMeta):
def get_global_rank(self): def get_global_rank(self):
"""Returns the global rank of the current device. """Returns the global rank of the current device.
:return: The global rank of the current device Returns:
:rtype: int int: The global rank of the current device
""" """
return self._global_ranks[ParallelMode.GLOBAL] return self._global_ranks[ParallelMode.GLOBAL]
def add_global_rank(self, parallel_mode: ParallelMode, rank: int): def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context. """Adds the global rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
:param rank: The rank to be added rank (int): The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance Raises:
of :class:`colossalai.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank self._global_ranks[parallel_mode] = rank
...@@ -102,12 +109,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -102,12 +109,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_local_rank(self, parallel_mode: ParallelMode): def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device. """Returns the local rank of the current device.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: The local rank of the current device for `parallel_mode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: int of :class:`colossalai.context.ParallelMode`.
Returns:
int: The local rank of the current device for `parallel_mode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
return self._local_ranks[parallel_mode] return self._local_ranks[parallel_mode]
...@@ -115,12 +125,13 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -115,12 +125,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_local_rank(self, parallel_mode: ParallelMode, rank: int): def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context. """Adds the local rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank.
:param rank: The rank to be added rank (int): The rank to be added.
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance Raises:
of :class:`colossalai.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank self._local_ranks[parallel_mode] = rank
...@@ -128,12 +139,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -128,12 +139,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_next_global_rank(self, parallel_mode: ParallelMode): def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device. """Returns the global rank of the next device.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: The global rank of the next device for `parallel_mode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: int of :class:`colossalai.context.ParallelMode`.
Returns:
int: The global rank of the next device for `parallel_mode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
...@@ -147,12 +161,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -147,12 +161,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_prev_global_rank(self, parallel_mode: ParallelMode): def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device. """Returns the global rank of the previous device.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: The global rank of the previous device for `parallel_mode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: int of :class:`colossalai.context.ParallelMode`.
Returns:
int: The global rank of the previous device for `parallel_mode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
...@@ -167,13 +184,16 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -167,13 +184,16 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether the current device is the first one """Returns a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`. among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: a boolean value indicating whether the current device is the first one AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
among its group for `parallel_mode` of :class:`colossalai.context.ParallelMode`.
:rtype: bool
Returns:
bool: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
""" """
rank = self.get_local_rank(parallel_mode) rank = self.get_local_rank(parallel_mode)
return rank == 0 return rank == 0
...@@ -182,13 +202,16 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -182,13 +202,16 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether the current device is the last one """Returns a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`. among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: a boolean value indicating whether the current device is the last one AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
among its group for `parallel_mode` of :class:`colossalai.context.ParallelMode`.
:rtype: bool
Returns:
bool: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
""" """
rank = self.get_local_rank(parallel_mode) rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode) world_size = self.get_world_size(parallel_mode)
...@@ -210,12 +233,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -210,12 +233,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_world_size(self, parallel_mode: ParallelMode): def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`. """Returns the world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: The world size for `parallel_mode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: int of :class:`colossalai.context.ParallelMode`.
Returns:
int: The world size for `parallel_mode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
return self._world_sizes[parallel_mode] return self._world_sizes[parallel_mode]
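A sketch tying the query methods above together, assuming the groups were created through the normal ``colossalai.launch`` path; the values in the comments are illustrative::

    from colossalai.context import ParallelMode
    from colossalai.core import global_context as gpc

    rank = gpc.get_global_rank()                         # e.g. 5
    dp_rank = gpc.get_local_rank(ParallelMode.DATA)      # rank inside the data-parallel group
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)  # number of pipeline stages

    if gpc.is_last_rank(ParallelMode.PIPELINE):
        pass  # only the final stage computes the loss in a pipelined run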
...@@ -223,12 +249,13 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -223,12 +249,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_world_size(self, parallel_mode: ParallelMode, world_size: int): def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`. """Adds world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:param world_size: The world size to be added world_size (int): The world size to be added
:type world_size: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance Raises:
of :class:`colossalai.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size self._world_sizes[parallel_mode] = world_size
...@@ -236,12 +263,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -236,12 +263,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_group(self, parallel_mode: ParallelMode): def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`. """Returns the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: The group of the current device for `parallel_mode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: torch.distributed.ProcessGroup of :class:`colossalai.context.ParallelMode`.
Returns:
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
return self._groups[parallel_mode] return self._groups[parallel_mode]
...@@ -249,12 +279,13 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -249,12 +279,13 @@ class ParallelContext(metaclass=SingletonMeta):
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup): def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`. """Adds the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:param group: The group to be added group (torch.distributed.ProcessGroup): The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance Raises:
of :class:`colossalai.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group self._groups[parallel_mode] = group
...@@ -262,12 +293,15 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -262,12 +293,15 @@ class ParallelContext(metaclass=SingletonMeta):
def get_ranks_in_group(self, parallel_mode: ParallelMode): def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group. """Returns the rank of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` Raises:
:return: the rank of the current device for `parallel_mode` in the group AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
:rtype: int of :class:`colossalai.context.ParallelMode`.
Returns:
int: The rank of the current device for `parallel_mode` in the group.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
return self._ranks_in_group[parallel_mode] return self._ranks_in_group[parallel_mode]
...@@ -275,28 +309,26 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -275,28 +309,26 @@ class ParallelContext(metaclass=SingletonMeta):
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list): def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group. """Adds the ranks of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:param ranks: List of ranks to be added ranks (list): List of ranks to be added
:type ranks: list
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance Raises:
of :class:`colossalai.context.ParallelMode` AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int): def init_global_dist(self, rank: int, world_size: int, backend: str, host: str, port: int):
"""Initializes the global distributed environment """Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int Args:
:param world_size: world size of the default process group rank (int): rank for the default process group.
:type world_size: int world_size (int): world size of the default process group.
:param host: the master address for distributed training backend (str): backend for ``torch.distributed``
:type host: str host (str): the master address for distributed training.
:param port: the master port for distributed training port (int): the master port for distributed training.
:type port: str
:param backend: backend for torch.distributed
:type backend: str
""" """
# initialize the default process group # initialize the default process group
init_method = f'tcp://{host}:{port}' init_method = f'tcp://{host}:{port}'
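A hedged sketch of a direct call; in practice ``colossalai.launch`` supplies these arguments from the environment, so the literal values below are placeholders::

    from colossalai.core import global_context as gpc

    # One process out of eight, rendezvousing over TCP on the master node.
    gpc.init_global_dist(rank=0, world_size=8, backend='nccl',
                         host='127.0.0.1', port=29500)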
...@@ -315,8 +347,9 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -315,8 +347,9 @@ class ParallelContext(metaclass=SingletonMeta):
def check_sanity(self): def check_sanity(self):
"""Checks sanity of the parallel context. """Checks sanity of the parallel context.
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product Raises:
of data paralle size, pipeline parallel size and tensor parallel size AssertionError: Raises an AssertionError if the world size does not equal the product
of data parallel size, pipeline parallel size and tensor parallel size.
""" """
dps = self.data_parallel_size dps = self.data_parallel_size
pps = self.pipeline_parallel_size pps = self.pipeline_parallel_size
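The invariant being checked, written out for a concrete, made-up topology::

    # 32 processes split as data=4, pipeline=2, tensor=4 pass the sanity check,
    # because 4 * 2 * 4 == 32 == world size.
    data_parallel_size, pipeline_parallel_size, tensor_parallel_size = 4, 2, 4
    assert data_parallel_size * pipeline_parallel_size * tensor_parallel_size == 32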
...@@ -341,7 +374,8 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -341,7 +374,8 @@ class ParallelContext(metaclass=SingletonMeta):
def init_parallel_groups(self): def init_parallel_groups(self):
"""Initializes the parallel groups. """Initializes the parallel groups.
:raises AssertionError: Raises an AssertionError if the field paralle is not present in the config file Raises:
AssertionError: Raises an AssertionError if the field parallel is not present in the config file.
""" """
# get rank and world size # get rank and world size
...@@ -411,11 +445,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -411,11 +445,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns a boolean value indicating whether `parallel_mode` is initialized """Returns a boolean value indicating whether `parallel_mode` is initialized
in the current system. in the current system.
:param parallel_mode: The chosen parallel mode Args:
:type parallel_mode: :class:`colossalai.context.ParallelMode` parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
:return: a boolean value indicating whether `parallel_mode` is initialized
in the current system Returns:
:rtype: bool bool: a boolean value indicating whether `parallel_mode` is initialized in the current system.
""" """
return parallel_mode in self._groups return parallel_mode in self._groups
...@@ -432,8 +466,8 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -432,8 +466,8 @@ class ParallelContext(metaclass=SingletonMeta):
def set_device(self, device_ordinal: int = None): def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices. """Sets distributed processes to be bound to devices.
:param device_ordinal: the device id to be bound to Args:
:type device_ordinal: int, optional device_ordinal (int, optional): the device id to be bound to
""" """
global_rank = self.get_global_rank() global_rank = self.get_global_rank()
if device_ordinal is None: if device_ordinal is None:
...@@ -447,8 +481,8 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -447,8 +481,8 @@ class ParallelContext(metaclass=SingletonMeta):
def set_seed(self, seed: int): def set_seed(self, seed: int):
"""Sets seeds for all random libraries. """Sets seeds for all random libraries.
:param seed: seed for random states Args:
:type seed: int seed (int): seed for random states
""" """
random.seed(seed) random.seed(seed)
np.random.seed(seed) np.random.seed(seed)
......
...@@ -11,8 +11,16 @@ from .process_group_initializer import ProcessGroupInitializer ...@@ -11,8 +11,16 @@ from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module @DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer): class Initializer_1D(ProcessGroupInitializer):
'''A ProcessGroupInitializer for 1d tensor parallelism. """A ProcessGroupInitializer for 1d tensor parallelism.
'''
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
"""
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
...@@ -20,8 +28,10 @@ class Initializer_1D(ProcessGroupInitializer): ...@@ -20,8 +28,10 @@ class Initializer_1D(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu. """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple Returns:
Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
1D tensor parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
......
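A sketch of how the returned tuple might be consumed, mirroring what ``ParallelContext`` does with each initializer's output; the import path and constructor keywords follow the argument list documented above and the topology is made up::

    from colossalai.context import Config
    from colossalai.context.process_group_initializer import Initializer_1D

    # Requires torch.distributed to be initialized first (e.g. via init_global_dist).
    initializer = Initializer_1D(rank=3, world_size=8, config=Config({}),
                                 data_parallel_size=1, pipeline_parallel_size=1,
                                 tensor_parallel_size=8)
    local_rank, group_world_size, process_group, ranks_in_group, mode = \
        initializer.init_dist_group()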
...@@ -22,12 +22,16 @@ def _check_summa_env_var(summa_dim): ...@@ -22,12 +22,16 @@ def _check_summa_env_var(summa_dim):
class Initializer_2D_Row(ProcessGroupInitializer): class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows. """2d tensor parallel initialization among rows.
:param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA Args:
:param args: Args used to initialize base class num_group (int): The number of all tensor groups.
:param kwargs: Kwargs used to initialize base class summa_dim (int): The dimension of SUMMA.
:type num_group: int rank (int): The rank of current process.
:type summa_dim: int world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, num_group, summa_dim, *args, **kwargs): def __init__(self, num_group, summa_dim, *args, **kwargs):
...@@ -37,9 +41,9 @@ class Initializer_2D_Row(ProcessGroupInitializer): ...@@ -37,9 +41,9 @@ class Initializer_2D_Row(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
Returns:
:return: 2D tensor row parallelism's information Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) 2D tensor row parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -64,13 +68,15 @@ class Initializer_2D_Row(ProcessGroupInitializer): ...@@ -64,13 +68,15 @@ class Initializer_2D_Row(ProcessGroupInitializer):
class Initializer_2D_Col(ProcessGroupInitializer): class Initializer_2D_Col(ProcessGroupInitializer):
"""2d tensor parallel initialization among cols. """2d tensor parallel initialization among cols.
:param num_group: The number of all tensor groups Args:
:param summa_dim: The dimension of SUMMA num_group (int): The number of all tensor groups.
:param args: Args used to initialize base class summa_dim (int): The dimension of SUMMA.
:param kwargs: Kwargs used to initialize base class rank (int): The rank of current process.
world_size (int): Size of whole communication world.
:type num_group: int config (Config): Running configuration.
:type summa_dim: int data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, num_group, summa_dim, *args, **kwargs): def __init__(self, num_group, summa_dim, *args, **kwargs):
...@@ -81,8 +87,9 @@ class Initializer_2D_Col(ProcessGroupInitializer): ...@@ -81,8 +87,9 @@ class Initializer_2D_Col(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor col parallelism's information Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2D tensor col parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -109,8 +116,13 @@ class Initializer_2D(ProcessGroupInitializer): ...@@ -109,8 +116,13 @@ class Initializer_2D(ProcessGroupInitializer):
""" """
Serve as the single entry point to 2D parallel initialization. Serve as the single entry point to 2D parallel initialization.
:param args: Args used to initialize ProcessGroupInitializer Args:
:param kwargs: Kwargs used to initialize ProcessGroupInitializer rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
...@@ -127,8 +139,10 @@ class Initializer_2D(ProcessGroupInitializer): ...@@ -127,8 +139,10 @@ class Initializer_2D(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
2D tensor parallelism's information in a list of tuples.
""" """
parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()] parallel_setting = [self.row_initializer.init_dist_group(), self.col_initializer.init_dist_group()]
return parallel_setting return parallel_setting
...@@ -31,14 +31,17 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int): ...@@ -31,14 +31,17 @@ def _check_tesseract_env_var(tesseract_dim: int, tesseract_dep: int):
# i row j col k dep # i row j col k dep
class Initializer_2p5D_ROW(ProcessGroupInitializer): class Initializer_2p5D_ROW(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among rows. """2.5d tensor parallel initialization among rows.
:param tesseract_dim: The dimension of tesseract Args:
:param tesseract_dep: The dimension of depth tesseract_dim (int): The dimension of tesseract.
:param args: Args used to initialize base class tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
:type tesseract_dim: int world_size (int): Size of whole communication world.
:type tesseract_dep: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
...@@ -50,10 +53,11 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): ...@@ -50,10 +53,11 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor row parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor row parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor row parallelism's information Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor row parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -80,14 +84,17 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer): ...@@ -80,14 +84,17 @@ class Initializer_2p5D_ROW(ProcessGroupInitializer):
class Initializer_2p5D_Col(ProcessGroupInitializer): class Initializer_2p5D_Col(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among cols. """2.5d tensor parallel initialization among cols.
:param tesseract_dim: The dimension of tesseract Args:
:param tesseract_dep: The dimension of depth tesseract_dim (int): The dimension of tesseract.
:param args: Args used to initialize base class tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
:type tesseract_dim: int world_size (int): Size of whole communication world.
:type tesseract_dep: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
...@@ -99,10 +106,11 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): ...@@ -99,10 +106,11 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor col parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor col parallelism's information Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor col parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -129,14 +137,17 @@ class Initializer_2p5D_Col(ProcessGroupInitializer): ...@@ -129,14 +137,17 @@ class Initializer_2p5D_Col(ProcessGroupInitializer):
class Initializer_2p5D_Dep(ProcessGroupInitializer): class Initializer_2p5D_Dep(ProcessGroupInitializer):
"""2p5D tensor parallel initialization among depths. """2.5D tensor parallel initialization among depths.
:param tesseract_dim: The dimension of tesseract Args:
:param tesseract_dep: The dimension of depth tesseract_dim (int): The dimension of tesseract.
:param args: Args used to initialize base class tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
:type tesseract_dim: int world_size (int): Size of whole communication world.
:type tesseract_dep: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
...@@ -148,10 +159,11 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): ...@@ -148,10 +159,11 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor depth parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor depth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor depth parallelism's information Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor depth parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -179,14 +191,17 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer): ...@@ -179,14 +191,17 @@ class Initializer_2p5D_Dep(ProcessGroupInitializer):
# i row j col k dep # i row j col k dep
class Initializer_2p5D_XZ(ProcessGroupInitializer): class Initializer_2p5D_XZ(ProcessGroupInitializer):
"""2p5d tensor parallel initialization among cols times dep. """2.5d tensor parallel initialization among cols times dep.
:param tesseract_dim: The dimension of tesseract Args:
:param tesseract_dep: The dimension of depth tesseract_dim (int): The dimension of tesseract.
:param args: Args used to initialize base class tesseract_dep (int): The dimension of depth.
rank (int): The rank of current process.
:type tesseract_dim: int world_size (int): Size of whole communication world.
:type tesseract_dep: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, tesseract_dim: int, tesseract_dep: int, *args): def __init__(self, tesseract_dim: int, tesseract_dep: int, *args):
...@@ -198,10 +213,11 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer): ...@@ -198,10 +213,11 @@ class Initializer_2p5D_XZ(ProcessGroupInitializer):
"Tensor parallel size should be depth * dim ** 2 in 2.5D parallel" "Tensor parallel size should be depth * dim ** 2 in 2.5D parallel"
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: 2p5D tensor colXdepth parallelism's information Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
2.5D tensor colXdepth parallelism's information in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -232,20 +248,14 @@ class Initializer_2p5D(ProcessGroupInitializer): ...@@ -232,20 +248,14 @@ class Initializer_2p5D(ProcessGroupInitializer):
""" """
Serve as the single entry point to Tesseract parallel initialization. Serve as the single entry point to Tesseract parallel initialization.
:param rank: The rank of current process Args:
:param world_size: Size of whole communication world rank (int): The rank of current process.
:param config: Running configuration world_size (int): Size of whole communication world.
:param data_parallel_size: Size of data parallel config (Config): Running configuration.
:param pipeline_parallel_size: Size of pipeline parallel data_parallel_size (int): Size of data parallel.
:param tensor_parallel_size: Size of tensor parallel pipeline_parallel_size (int): Size of pipeline parallel.
:param depth: The depth of 2p5d parallel tensor_parallel_size (int): Size of tensor parallel.
:type rank: int depth (int): The depth of 2.5d parallel.
:type world_size: int
:type config: Config
:type data_parallel_size: int
:type pipeline_parallel_size: int
:type tensor_parallel_size: int
:type depth: int
""" """
def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int, def __init__(self, rank: int, world_size: int, config: Config, data_parallel_size: int, pipeline_parallel_size: int,
...@@ -266,9 +276,11 @@ class Initializer_2p5D(ProcessGroupInitializer): ...@@ -266,9 +276,11 @@ class Initializer_2p5D(ProcessGroupInitializer):
self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args) self.xz_initializer = Initializer_2p5D_XZ(self.tesseract_dim, self.tesseract_dep, *args)
def init_dist_group(self): def init_dist_group(self):
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu. """Initialize 2.5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 2.5D tensor parallelism's information in a list of tuples.
""" """
parallel_setting = [ parallel_setting = [
self.col_initializer.init_dist_group(), self.col_initializer.init_dist_group(),
......
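The size relation asserted throughout these 2.5D initializers, spelled out with concrete, made-up numbers::

    # tesseract_dim=2 with depth (tesseract_dep)=2 needs depth * dim ** 2 ranks.
    tesseract_dim, tesseract_dep = 2, 2
    tensor_parallel_size = tesseract_dep * tesseract_dim ** 2
    assert tensor_parallel_size == 8   # 2 * 2 ** 2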
...@@ -26,12 +26,15 @@ def _check_depth_env_var(depth): ...@@ -26,12 +26,15 @@ def _check_depth_env_var(depth):
class Initializer_3D_Input(ProcessGroupInitializer): class Initializer_3D_Input(ProcessGroupInitializer):
"""3D tensor parallel initialization among input. """3D tensor parallel initialization among input.
:param num_group: The number of all tensor groups Args:
:param depth: Depth of 3D parallelism num_group (int): The number of all tensor groups.
:param args: Args used in base class depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
:type num_group: int world_size (int): Size of whole communication world.
:type depth: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, num_group: int, depth: int, *args): def __init__(self, num_group: int, depth: int, *args):
...@@ -42,8 +45,9 @@ class Initializer_3D_Input(ProcessGroupInitializer): ...@@ -42,8 +45,9 @@ class Initializer_3D_Input(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu. """Initialize 3D tensor parallel groups among input, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among input Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among input in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -70,12 +74,15 @@ class Initializer_3D_Input(ProcessGroupInitializer): ...@@ -70,12 +74,15 @@ class Initializer_3D_Input(ProcessGroupInitializer):
class Initializer_3D_Weight(ProcessGroupInitializer): class Initializer_3D_Weight(ProcessGroupInitializer):
"""3D tensor parallel initialization among weight. """3D tensor parallel initialization among weight.
:param num_group: The number of all tensor groups Args:
:param depth: Depth of 3D parallelism num_group (int): The number of all tensor groups.
:param args: Args used in base class depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
:type num_group: int world_size (int): Size of whole communication world.
:type depth: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, num_group: int, depth: int, *args): def __init__(self, num_group: int, depth: int, *args):
...@@ -86,8 +93,9 @@ class Initializer_3D_Weight(ProcessGroupInitializer): ...@@ -86,8 +93,9 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu. """Initialize 3D tensor parallel groups among weight, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among weight Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among weight in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -114,12 +122,15 @@ class Initializer_3D_Weight(ProcessGroupInitializer): ...@@ -114,12 +122,15 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
class Initializer_3D_Output(ProcessGroupInitializer): class Initializer_3D_Output(ProcessGroupInitializer):
"""3D tensor parallel initialization among output. """3D tensor parallel initialization among output.
:param num_group: The number of all tensor groups Args:
:param depth: Depth of 3D parallelism num_group (int): The number of all tensor groups.
:param args: Args used in base class depth (int): Depth of 3D parallelism.
rank (int): The rank of current process.
:type num_group: int world_size (int): Size of whole communication world.
:type depth: int config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, num_group: int, depth: int, *args): def __init__(self, num_group: int, depth: int, *args):
...@@ -130,8 +141,9 @@ class Initializer_3D_Output(ProcessGroupInitializer): ...@@ -130,8 +141,9 @@ class Initializer_3D_Output(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu. """Initialize 3D tensor parallel groups among output, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information among output Returns:
:rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode) Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
3D tensor parallelism's information among output in a tuple.
""" """
local_rank = None local_rank = None
ranks_in_group = None ranks_in_group = None
...@@ -158,7 +170,14 @@ class Initializer_3D_Output(ProcessGroupInitializer): ...@@ -158,7 +170,14 @@ class Initializer_3D_Output(ProcessGroupInitializer):
@DIST_GROUP_INITIALIZER.register_module @DIST_GROUP_INITIALIZER.register_module
class Initializer_3D(ProcessGroupInitializer): class Initializer_3D(ProcessGroupInitializer):
"""Serve as the single entry point to 3D parallel initialization. """Serve as the single entry point to 3D parallel initialization.
:param args: Args used to initialize ProcessGroupInitializer
Args:
rank (int): The rank of current process.
world_size (int): Size of whole communication world.
config (Config): Running configuration.
data_parallel_size (int): Size of data parallel.
pipeline_parallel_size (int): Size of pipeline parallel.
tensor_parallel_size (int): Size of tensor parallel.
""" """
def __init__(self, *args): def __init__(self, *args):
...@@ -175,8 +194,10 @@ class Initializer_3D(ProcessGroupInitializer): ...@@ -175,8 +194,10 @@ class Initializer_3D(ProcessGroupInitializer):
def init_dist_group(self): def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu. """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode) Returns:
List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
Whole 3D tensor parallelism's information in a list of tuples.
""" """
parallel_setting = [ parallel_setting = [
self.input_initializer.init_dist_group(), self.input_initializer.init_dist_group(),
......
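As a rough counterpart to the 2.5D relation above, the 3D mode arranges tensor-parallel ranks in a cube; assuming, as in the upstream implementation, that ``depth`` is the cube root of the tensor parallel size::

    # depth=2 implies 2 ** 3 = 8 tensor-parallel ranks, shared by the
    # input, weight and output groups shown above.
    depth = 2
    tensor_parallel_size = depth ** 3
    assert tensor_parallel_size == 8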