Commit ec5086c4 authored by Liang Bowen, committed by アマデウス

Refactored docstrings to Google style

parent 53b1b6e3
@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
 class Initializer_Data(ProcessGroupInitializer):
     """A ProcessGroupInitializer for data parallelism.

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -22,8 +27,9 @@ class Initializer_Data(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize data parallel groups, and assign local_ranks and groups to each gpu.

-        :return: Data parallelism's information
-        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A tuple containing the data parallelism information.
         """
         local_rank = None
         ranks_in_group = None
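The tuple documented above is easiest to picture with a sketch. The following is a hypothetical illustration of how data-parallel groups can be derived from a flat rank space using plain `torch.distributed`; it is not the code from this commit, and the rank layout is an assumption.

```python
# Hypothetical sketch of data-parallel group construction, assuming processes
# holding the same model shard land in the same data-parallel group.
import torch.distributed as dist

def build_data_parallel_groups(world_size: int, data_parallel_size: int):
    num_groups = world_size // data_parallel_size
    my_rank = dist.get_rank()
    my_group, ranks_in_group = None, None
    for i in range(num_groups):
        # e.g. world_size=8, data_parallel_size=2 -> [0, 4], [1, 5], [2, 6], [3, 7]
        ranks = [i + j * num_groups for j in range(data_parallel_size)]
        group = dist.new_group(ranks)  # every rank must call new_group for every group
        if my_rank in ranks:
            my_group, ranks_in_group = group, ranks
    local_rank = ranks_in_group.index(my_rank)
    return local_rank, data_parallel_size, my_group, ranks_in_group
```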
@@ -12,8 +12,13 @@ class Initializer_Model(ProcessGroupInitializer):
     """A ProcessGroupInitializer for model parallelism (model parallel group contains pipeline and tensor parallel
     groups).

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -24,8 +29,9 @@ class Initializer_Model(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize model parallel groups, and assign local_ranks and groups to each gpu.

-        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
-        :rtype: Tuple
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A tuple containing the model parallelism information.
         """
         local_rank = None
         ranks_in_group = None
@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
 class Initializer_Pipeline(ProcessGroupInitializer):
     """A ProcessGroupInitializer for pipeline parallelism.

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -23,8 +28,9 @@ class Initializer_Pipeline(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize pipeline parallel groups, and assign local_ranks and groups to each gpu.

-        :return: Pipeline parallelism's information
-        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        Returns:
+            List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
+                A list of tuples containing the pipeline parallelism information.
         """
         dist_settings = list()
         for i in range(self.data_parallel_size):
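`init_dist_group` returns a list here because every data-parallel replica gets its own set of pipeline groups, as the loop over `self.data_parallel_size` above suggests. A rough, illustrative rank layout follows; the stride scheme is an assumption, not the commit's code.

```python
# Hypothetical pipeline-group layout: ranks inside one data-parallel replica
# are split into pipeline stages with a fixed stride. Illustrative only.
def pipeline_group_ranks(world_size: int, data_parallel_size: int,
                         pipeline_parallel_size: int):
    ranks_per_dp_group = world_size // data_parallel_size
    stride = ranks_per_dp_group // pipeline_parallel_size
    groups = []
    for i in range(data_parallel_size):
        offset = i * ranks_per_dp_group
        for j in range(stride):
            # e.g. world_size=8, dp=2, pp=2 -> [0, 2], [1, 3], [4, 6], [5, 7]
            groups.append(list(range(offset + j, offset + ranks_per_dp_group, stride)))
    return groups
```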
@@ -15,8 +15,13 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
     In Sequence Parallelism, each GPU holds the full copy of model weights,
     thus, gradient all-reduce occurs across all processes in the same pipeline stage.

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -27,8 +32,8 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize Sequence Parallel process groups used for gradient all-reduce.

-        :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
-        :rtype: Tuple
+        Returns:
+            Tuple: A tuple (local_rank, group_world_size, process_group, ranks_in_group, mode).
         """
         local_rank = None
         ranks_in_group = None
@@ -52,8 +57,13 @@ class Initializer_Sequence_DP(ProcessGroupInitializer):
 class Initializer_Sequence(ProcessGroupInitializer):
     """A ProcessGroupInitializer for sequence parallelism.

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
@@ -66,11 +76,12 @@ class Initializer_Sequence(ProcessGroupInitializer):
         """Initialize Sequence parallel process groups and assign local_ranks and groups to each gpu.
         Sequence parallelism requires 2 process groups. The first is for model forward where several processes
-        exchange paritial query, key and value embedding to compute self attention values. The second is for
+        exchange partial query, key and value embedding to compute self attention values. The second is for
         all-reduce to synchronize the model parameters.

-        :return: Sequence parallelism's information
-        :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
+        Returns:
+            List[Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode)]:
+                A list of tuples containing the sequence parallelism information.
         """
         parallel_setting = []
@@ -12,8 +12,13 @@ from ..parallel_mode import ParallelMode
 class Initializer_Tensor(ProcessGroupInitializer):
     """A ProcessGroupInitializer for tensor parallelism.

-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -22,8 +27,9 @@ class Initializer_Tensor(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize tensor parallel groups, and assign local_ranks and groups to each gpu.

-        :return: Tensor parallelism's information
-        :rtype: Tuple(local_rank, group_world_size, process_group, ranks_in_group, mode)
+        Returns:
+            Tuple (local_rank, group_world_size, process_group, ranks_in_group, mode):
+                A tuple containing the tensor parallelism information.
         """
         local_rank = None
         ranks_in_group = None
@@ -9,19 +9,13 @@ from colossalai.context import Config
 class ProcessGroupInitializer(ABC):
     """An object, knowing the parallelism configuration, that initializes parallel groups.

-    :param rank: The rank of current process
-    :param world_size: Size of whole communication world
-    :param config: Running configuration
-    :param data_parallel_size: Size of data parallel
-    :param pipeline_parallel_size: Size of pipeline parallel
-    :param tensor_parallel_size: Size of tensor parallel
-
-    :type rank: int
-    :type world_size: int
-    :type config: Config
-    :type data_parallel_size: int
-    :type pipeline_parallel_size: int
-    :type tensor_parallel_size: int
+    Args:
+        rank (int): The rank of current process.
+        world_size (int): Size of whole communication world.
+        config (Config): Running configuration.
+        data_parallel_size (int): Size of data parallel.
+        pipeline_parallel_size (int): Size of pipeline parallel.
+        tensor_parallel_size (int): Size of tensor parallel.
     """

     def __init__(self,
                  rank: int,
@@ -16,8 +16,8 @@ _SEED_MANAGER = SeedManager()
 def get_seeds():
     """Returns the seeds of the seed manager.

-    :return: The seeds of the seed manager
-    :rtype: dict
+    Returns:
+        dict: The seeds of the seed manager.
     """
     return _SEED_MANAGER.seeds
@@ -25,8 +25,8 @@ def get_seeds():
 def get_states(copy=False):
     """Returns the seed states of the seed manager.

-    :return: The seed states of the seed manager
-    :rtype: dict
+    Returns:
+        dict: The seed states of the seed manager.
     """
     states = _SEED_MANAGER.seed_states
@@ -43,8 +43,8 @@ def get_states(copy=False):
 def get_current_mode():
     """Returns the current mode of the seed manager.

-    :return: The current mode of the seed manager.
-    :rtype: :class:`torch.ByteTensor`
+    Returns:
+        :class:`torch.ByteTensor`: The current mode of the seed manager.
     """
     return _SEED_MANAGER.current_mode
@@ -52,12 +52,16 @@ def get_current_mode():
 def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
     """Adds a seed to the seed manager for `parallel_mode`.

-    :param parallel_mode: The chosen parallel mode
-    :type parallel_mode: :class:`colossalai.context.ParallelMode`
-    :param seed: The seed to be added
-    :type seed: int
-    :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
-        :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
+    Args:
+        parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+        seed (int): The seed to be added.
+
+    Raises:
+        AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
+            :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
+
+    Note:
+        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
     _SEED_MANAGER.add_seed(parallel_mode, seed, overwrite)
@@ -65,8 +69,12 @@ def add_seed(parallel_mode: ParallelMode, seed: int, overwrite: bool = False):
 def set_mode(parallel_mode: ParallelMode):
     """Sets the current mode of the seed manager.

-    :param parallel_mode: The chosen parallel mode
-    :type parallel_mode: :class:`colossalai.context.ParallelMode`
+    Args:
+        parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+
+    Note:
+        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
     _SEED_MANAGER.set_mode(parallel_mode)
@@ -74,11 +82,12 @@ def set_mode(parallel_mode: ParallelMode):
 def set_seed_states(parallel_mode: ParallelMode, state: Tensor):
     """Sets the state of the seed manager for `parallel_mode`.

-    :param parallel_mode: The chosen parallel mode
-    :type parallel_mode: :class:`colossalai.context.ParallelMode`
-    :param state: the state to be set
-    :type state: :class:`torch.Tensor`
-    :raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
+    Args:
+        parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+        state (:class:`torch.Tensor`): The state to be set.
+
+    Raises:
+        AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
     """
     _SEED_MANAGER.set_state(parallel_mode, state)
@@ -98,6 +107,9 @@ def seed(parallel_mode: ParallelMode):
         with seed(ParallelMode.DATA):
             output = F.dropout(input)

+    Note:
+        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
     try:
         # set to new mode
@@ -125,6 +137,9 @@ def with_seed(func, parallel_mode: ParallelMode):
         wrapped_forward = with_seed(forward, ParallelMode.DATA)
         out = wrapped_forward(input)

+    Note:
+        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """
     @functools.wraps(func)
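Putting these functions together, a usage sketch might look as follows. It assumes `colossalai.launch` has already initialized the distributed environment and a CUDA device is available; the import path is inferred from the file layout implied by this diff.

```python
import torch
import torch.nn.functional as F
from colossalai.context import ParallelMode
from colossalai.context.random import add_seed, seed, with_seed

add_seed(ParallelMode.DATA, 1024)      # register an RNG state for the DATA mode

input_tensor = torch.rand(4, 4, device='cuda')

# run dropout under the DATA seed, restoring the previous mode afterwards
with seed(ParallelMode.DATA):
    output = F.dropout(input_tensor)

# decorator-style variant wrapping an arbitrary callable
dropout_under_data_seed = with_seed(F.dropout, ParallelMode.DATA)
output = dropout_under_data_seed(input_tensor)
```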
@@ -9,6 +9,10 @@ from colossalai.context.parallel_mode import ParallelMode
 class SeedManager:
     """This class is a manager of all random seeds involved in the system.

+    Note:
+        The parallel_mode should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """

     def __init__(self):
@@ -30,12 +34,12 @@ class SeedManager:
     def set_state(self, parallel_mode: ParallelMode, state: Tensor):
         """Sets the state of the seed manager for `parallel_mode`.

-        :param parallel_mode: The chosen parallel mode
-        :type parallel_mode: :class:`colossalai.context.ParallelMode`
-        :param state: the state to be set
-        :type state: :class:`torch.Tensor`
-        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager
+        Args:
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+            state (:class:`torch.Tensor`): The state to be set.
+
+        Raises:
+            AssertionError: Raises an AssertionError if `parallel_mode` is not found in the seed manager.
         """
         assert parallel_mode in self._seed_states, f'Parallel mode {parallel_mode} is not found in the seed manager'
         self._seed_states[parallel_mode] = state
@@ -43,8 +47,8 @@ class SeedManager:
     def set_mode(self, parallel_mode: ParallelMode):
         """Sets the current mode of the seed manager.

-        :param parallel_mode: The chosen parallel mode
-        :type parallel_mode: :class:`colossalai.context.ParallelMode`
+        Args:
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
         """
         if self.current_mode:
             # save the current state for current mode
@@ -57,14 +61,14 @@ class SeedManager:
     def add_seed(self, parallel_mode: ParallelMode, seed: int, overwrtie: bool = False):
         """Adds a seed to the seed manager for `parallel_mode`.

-        :param parallel_mode: The chosen parallel mode
-        :type parallel_mode: :class:`colossalai.context.ParallelMode`
-        :param seed: The seed to be added
-        :type seed: int
-        :param overwrtie: Whether allows to overwrite the seed that has been set already
-        :type overwrtie: bool, optional
-        :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
-            :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added
+        Args:
+            parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode.
+            seed (int): The seed to be added.
+            overwrtie (bool, optional): Whether to allow overwriting a seed that has already been set.
+
+        Raises:
+            AssertionError: Raises an AssertionError if `parallel_mode` is not an instance of
+                :class:`colossalai.context.ParallelMode` or the seed for `parallel_mode` has been added.
         """
         assert isinstance(parallel_mode, ParallelMode), 'A valid ParallelMode must be provided'
         if overwrtie is False:
@@ -19,20 +19,37 @@ class Engine:
     :meth:`step` which is based on the given :attr:`schedule` over each batch of a dataset.
     It controls an iteration in training.

-    :param model: The neural network model
-    :type model: ``torch.nn.Module``
-    :param optimizer: Optimizer for updating the parameters
-    :type optimizer: ``torch.optim.Optimizer``
-    :param criterion: Loss function for calculating loss
-    :type criterion: ``torch.nn.modules.loss._Loss``, optional
-    :param gradient_handlers: A list of gradient handler used in backward
-    :type gradient_handlers: a list of ``BaseGradientHandler``, optional
-    :param clip_grad_norm: The norm of gradient clipping
-    :type clip_grad_norm: float, optional
-    :param ophook_list: List of ophook
-    :type ophook_list: list
-    :param verbose: whether to display log info
-    :type verbose: bool
+    Args:
+        model (``torch.nn.Module``): The neural network model.
+        optimizer (``torch.optim.Optimizer``): Optimizer for updating the parameters.
+        criterion (``torch.nn.modules.loss._Loss``, optional): Loss function for calculating loss.
+        gradient_handlers (List[``BaseGradientHandler``], optional): A list of gradient handlers used in backward.
+        clip_grad_norm (float, optional): The norm of gradient clipping.
+        ophook_list (list): List of ophooks.
+        verbose (bool): Whether to display log info.
+
+    Examples:
+        >>> # define model, criterion, optimizer, lr_scheduler and train_dataloader for your training
+        >>> model = ...
+        >>> criterion = ...
+        >>> optimizer = ...
+        >>> train_dataloader = ...
+        >>> engine, _, _, _ = colossalai.initialize(model, optimizer, criterion)
+        >>> engine.train()
+        >>> for inputs, labels in train_dataloader:
+        >>>     # set gradients to zero
+        >>>     engine.zero_grad()
+        >>>     # run forward pass
+        >>>     outputs = engine(inputs)
+        >>>     # compute loss value and run backward pass
+        >>>     loss = engine.criterion(outputs, labels)
+        >>>     engine.backward(loss)
+        >>>     # update parameters
+        >>>     engine.step()
+
+    Examples of using Engine in training can be found in
+    `Training with engine and trainer <https://www.colossalai.org/docs/basics/engine_trainer>`_ and
+    `Run resnet cifar10 with engine <https://github.com/hpcaitech/ColossalAI-Examples/blob/main/image/resnet/run_resnet_cifar10_with_engine.py>`_.
     """

     def __init__(self,
@@ -113,10 +130,10 @@ class Engine:
         return self.optimizer.step()

     def backward(self, loss: Tensor):
-        """Start backward propagation given the loss value computed by a loss function
+        """Start backward propagation given the loss value computed by a loss function.

-        :param loss: Loss value computed by a loss function
-        :type loss: :class:`torch.Tensor`
+        Args:
+            loss (:class:`torch.Tensor`): Loss value computed by a loss function.
         """
         ret = self.optimizer.backward(loss)
         for ophook in self._ophook_list:
@@ -124,34 +141,22 @@ class Engine:
         return ret

     def backward_by_grad(self, tensor, grad):
-        """Start backward propagation given the gradient of the output tensor
+        """Start backward propagation given the gradient of the output tensor.

-        :param tensor: Output tensor
-        :type tensor: :class:`torch.Tensor`
-        :param grad: Gradient passed back to the output
-        :type grad: :class:`torch.Tensor`
+        Args:
+            tensor (:class:`torch.Tensor`): Output tensor.
+            grad (:class:`torch.Tensor`): Gradient passed back to the output.
         """
         ret = self.optimizer.backward_by_grad(tensor, grad)
         for ophook in self._ophook_list:
             ophook.post_iter()
         return ret

-    def calc_loss(self, *args, **kwargs):
-        """Compute the loss value
-
-        :param args: Args used in criterion function
-        :param kwargs: Kwargs used in criterion function
-
-        :return: The loss value
-        :rtype: :class:`torch.Tensor`
-        """
-        return self.criterion(*args, **kwargs)
-
     def __call__(self, *args, **kwargs):
-        """Run the forward step for the model
+        """Run the forward step for the model.

-        :return: Output the model
-        :rtype: Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`
+        Returns:
+            Tuple[:class:`torch.Tensor`] or :class:`torch.Tensor`: Output of the model.
         """
         return self.model(*args, **kwargs)
@@ -8,10 +8,9 @@ class BaseGradientHandler(ABC):
     """A basic helper class to handle all-reduce operations of gradients across different parallel groups
     before optimization.

-    :param model: Model where the gradients accumulate
-    :param optimizer: Optimizer for updating the parameters
-    :type model: Module
-    :type optimizer: Optimizer
+    Args:
+        model (Module): Model where the gradients accumulate.
+        optimizer (Optimizer): Optimizer for updating the parameters.
     """

     def __init__(self, model, optimizer):
         self._model = model
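To make the subclassing contract concrete, here is a hypothetical handler that naively all-reduces every gradient; the method name `handle_gradient` and the import path are assumptions based on ColossalAI's layout, not part of this diff.

```python
import torch.distributed as dist
from colossalai.engine.gradient_handler import BaseGradientHandler

class NaiveAllReduceGradientHandler(BaseGradientHandler):
    """Sketch: average every parameter gradient across all ranks."""

    def handle_gradient(self):
        world_size = dist.get_world_size()
        for param in self._model.parameters():
            if param.grad is not None:
                dist.all_reduce(param.grad)   # sum gradients across ranks
                param.grad.div_(world_size)   # turn the sum into a mean
```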
@@ -17,12 +17,11 @@ import math
 class MemTracerOpHook(BaseOpHook):
     """
     Collect GPU memory usage information

-    :param warmup: This parameter indicates how many iterations to truncate before profiling, defaults to 50
-    :type warmup: int
-    :param refreshrate: This parameter decides the frequency of write file, defaults to 10
-    :type refreshrate: int
-    :param data_prefix: The prefix of the stats data file, defaults to "memstats"
-    :type data_prefix: string
+    Args:
+        warmup (int): This parameter indicates how many iterations to truncate before profiling, defaults to 50.
+        refreshrate (int): This parameter decides how often the stats file is written, defaults to 10.
+        data_prefix (string): The prefix of the stats data file, defaults to "memstats".
     """

     def __init__(self, warmup: int = 50, refreshrate: int = 10, data_prefix: str = "memstats"):
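A construction sketch with the documented defaults spelled out; the import path is an assumption, and the hook is ultimately handed to the Engine through the `ophook_list` argument documented earlier.

```python
from colossalai.engine.ophooks import MemTracerOpHook  # import path assumed

# skip the first 50 iterations, flush stats every 10 refreshes, and write
# files whose names start with "memstats" (these are the documented defaults)
mem_hook = MemTracerOpHook(warmup=50, refreshrate=10, data_prefix="memstats")
```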
@@ -15,8 +15,12 @@ class BaseSchedule(ABC):
     """A basic helper class to control the process of training or evaluation.
     It mainly composes of forward_backward_step for gradient backward and
     optimizer_step for parameters update.
-    For the convenience to enable FP16, we aggreate all codes that contain the
+    For the convenience to enable FP16, we aggregate all codes that contain the
     control of FP16 in class schedule.

+    Args:
+        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
+            and it will be executed in load_batch.
     """

     def __init__(self, batch_data_process_func: Callable = None):
@@ -46,13 +50,12 @@ class BaseSchedule(ABC):
         """Loads a batch from data iterator. It returns the data and labels which are
         already in the same GPU as where the model is.

-        :param data_iter: Data iterator from which get a batch of data
-        :type data_iter: DataIter
-        :param to_gpu: Whether the data should be moved to GPU
-        :type to_gpu: bool, optional
+        Args:
+            data_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
+            to_gpu (bool, optional): Whether the data should be moved to GPU.

-        :return: (data, label)
-        :rtype: (:class:`Tensor`, :class:`torch.Tensor`)
+        Returns:
+            Tuple (:class:`torch.Tensor`, :class:`torch.Tensor`): A tuple of (data, label).
         """
         if data_iter is None:
             raise RuntimeError('Dataloader is not defined.')
@@ -87,16 +90,12 @@ class BaseSchedule(ABC):
                               ):
         """The process function over a batch of dataset for training or evaluation.

-        :param engine: Colossalai training engine
-        :type engine: colossalai.engine.Engine
-        :param data_iter: Data iterator from which get a batch of data
-        :type data_iter: DataIter
-        :param forward_only: If True, the process won't include backward
-        :type forward_only: bool
-        :param return_loss: If False, the loss won't be returned
-        :type return_loss: bool, optional
-        :param return_output_label: If False, the output and label won't be returned
-        :type return_output_label: bool, optional
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            data_iter (Iterable): Data iterator from which to get a batch of data, obtained by calling iter(dataloader).
+            forward_only (bool): If True, the process won't include backward.
+            return_loss (bool, optional): If False, the loss won't be returned.
+            return_output_label (bool, optional): If False, the output and label won't be returned.
         """
         pass
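The `batch_data_process_func` hook is clearest by example. The dictionary keys below are purely illustrative assumptions about one possible dataloader, and the import path is assumed; `NonPipelineSchedule` (documented in the next hunk) is one subclass that forwards this argument to `BaseSchedule`.

```python
from colossalai.engine.schedule import NonPipelineSchedule  # import path assumed

def batch_data_process_func(batch):
    # assume the dataloader yields dicts with these (illustrative) keys
    return batch['input_ids'], batch['labels']

schedule = NonPipelineSchedule(batch_data_process_func=batch_data_process_func)
```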
@@ -15,6 +15,10 @@ class NonPipelineSchedule(BaseSchedule):
     During one process, it loads a batch of dataset and feeds it to the model.
     After getting the output and calculating the loss, it will use :meth:`step`
     to update the parameters if it is in training mode.

+    Args:
+        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
+            and it will be executed in load_batch.
     """

     def forward_backward_step(self,
@@ -23,22 +27,19 @@ class NonPipelineSchedule(BaseSchedule):
                               forward_only: bool = False,
                               return_loss: bool = True,
                               return_output_label: bool = True):
-        """The process function that loads loads a batch of dataset and feeds it to the model.
-        The returned labels and loss will None if :attr:`return_loss` is False.
+        """The process function that loads a batch of dataset and feeds it to the model.
+        The returned labels and loss will be None if :attr:`return_loss` is False.

-        :param engine: Model for training and inference
-        :param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
-        :param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
-        :param return_loss: Loss will be returned if True
-        :param return_output_label: Output and label will be returned if True
-        :type engine: Iterator
-        :type data_iter: Iterator
-        :type forward_only: bool, optional
-        :type return_loss: bool, optional
-        :type return_output_label: bool, optional
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
+            forward_only (bool, optional):
+                If True, the model is run for the forward pass, else back propagation will be executed.
+            return_loss (bool, optional): Loss will be returned if True.
+            return_output_label (bool, optional): Output and label will be returned if True.

-        :return: (output, label, loss)
-        :rtype: Tuple[:class:`torch.Tensor`]
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
         """
         assert forward_only or return_loss, \
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
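A sketch of one evaluation step driven by this method; `engine` and `test_dataloader` are assumed to come from `colossalai.initialize`, as in the Engine example earlier, and the import path is assumed.

```python
from colossalai.engine.schedule import NonPipelineSchedule  # import path assumed

schedule = NonPipelineSchedule()
data_iter = iter(test_dataloader)   # engine/test_dataloader from colossalai.initialize
output, label, loss = schedule.forward_backward_step(
    engine, data_iter, forward_only=True, return_loss=True)
```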
@@ -41,14 +41,13 @@ class PipelineSchedule(BaseSchedule):
     It uses non-interleaved 1F1B strategy. Other properties are similar as
     :class:`NonPipelineSchedule`.

-    :param num_microbatches: The number of microbatches
-    :type num_microbatches: int
-    :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
-    :type batch_data_process_func: Callable, optional
-    :param tensor_shape: Specified shape in pipeline communication
-    :type tensor_shape: torch.Size, optional
-    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
-    :type scatter_gather_tensors: bool, optional
+    Args:
+        num_microbatches (int): The number of microbatches.
+        batch_data_process_func (Callable, optional):
+            The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
+        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
+        scatter_gather_tensors (bool, optional):
+            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
     """

     def __init__(self,
@@ -131,19 +130,14 @@ class PipelineSchedule(BaseSchedule):
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.

-        :param engine: Your engine object
-        :type engine: colossalai.engine.Engine
-        :param input_tensor: Input tensor for this pipeline stage
-        :type input_tensor: :class:`torch.Tensor`
-        :param return_tensors: A list of tensors to return
-        :type return_tensors: List[:class:`torch.Tensor`]
-        :param return_output_label: Whether returns output labels
-        :type return_output_label: bool, optional
-        :param accum_loss: Where accumulated loss stores
-        :type accum_loss: optional
-        :return: output or the loss value of the current pipeline stage
-        :rtype: :class:`torch.Tensor`
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
+            return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
+            return_output_label (bool, optional): Whether to return output labels.
+            accum_loss (optional): Where the accumulated loss is stored.
+
+        Returns:
+            :class:`torch.Tensor`: Output or the loss value of the current pipeline stage.
         """
         data, label = self.load_micro_batch()
         output_tensor = self._call_engine(engine.model, input_tensor, data)
@@ -173,17 +167,14 @@ class PipelineSchedule(BaseSchedule):
         Returns the gradients with respect to the input tensor (None if first stage).
         This is a helper function and can be ignored by users.

-        :param engine: your engine object
-        :type engine: colossalai.engine.Engine
-        :param input_tensor: input tensor for this pipeline stage
-        :type input_tensor: :class:`torch.Tensor`
-        :param output_tensor: output tensor for this pipeline stage
-        :type output_tensor: :class:`torch.Tensor`
-        :param output_tensor_grad: gradient of output tensor for this pipeline stage
-        :type output_tensor_grad: :class:`torch.Tensor`
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
+            output_tensor (:class:`torch.Tensor`): Output tensor for this pipeline stage.
+            output_tensor_grad (:class:`torch.Tensor`): Gradient of output tensor for this pipeline stage.
+
+        Returns:
+            :class:`torch.Tensor`: Gradient of input tensor.
         """

         # Retain the grad on the input_tensor.
@@ -207,19 +198,16 @@ class PipelineSchedule(BaseSchedule):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.

-        :param engine: Your engine object
-        :type engine: colossalai.engine.Engine
-        :param data_iter: Dataloader as the form of an iterator, obtained by calling iter(dataloader)
-        :type data_iter: Iterable
-        :param forward_only: Whether run forward step only. Default is false. If true, no backward will be run.
-        :type forward_only: bool
-        :param return_loss: Whether returns the loss value. Default is true.
-        :type return_loss: bool
-        :param return_output_label: If False, the output and label won't be returned
-        :type return_output_label: bool
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
+            forward_only (bool, optional):
+                Whether to run the forward step only. Default is false. If true, no backward will be run.
+            return_loss (bool, optional): Whether to return the loss value. Default is true.
+            return_output_label (bool, optional): If False, the output and label won't be returned.
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
         """
         assert forward_only or return_loss, \
@@ -354,16 +342,14 @@ class InterleavedPipelineSchedule(PipelineSchedule):
     It uses interleaved 1F1B strategy. Other properties are similar as
     :class:`NonPipelineSchedule`.

-    :param num_microbatches: The number of microbatches
-    :type num_microbatches: int
-    :param num_model_chunks: The number of model chunks
-    :type num_model_chunks: int
-    :param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
-    :type batch_data_process_func: Callable, optional
-    :param tensor_shape: Specified shape in pipeline communication
-    :type tensor_shape: torch.Size, optional
-    :param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
-    :type scatter_gather_tensors: bool, optional
+    Args:
+        num_microbatches (int): The number of microbatches.
+        num_model_chunks (int): The number of model chunks.
+        batch_data_process_func (Callable, optional):
+            The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
+        tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
+        scatter_gather_tensors (bool, optional):
+            If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
     """
     assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
         'num_microbatches must be an integer multiple of pipeline parallel world size'
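Construction sketches following the two docstrings above; the values are illustrative and the import path is assumed. Note the constraint enforced by the assert shown above: `num_microbatches` must be an integer multiple of the pipeline-parallel world size.

```python
from colossalai.engine.schedule import (  # import path assumed
    PipelineSchedule, InterleavedPipelineSchedule)

# plain non-interleaved 1F1B
schedule = PipelineSchedule(num_microbatches=4)

# interleaved 1F1B with two model chunks per stage; with a pipeline-parallel
# world size of 4, num_microbatches=8 satisfies the divisibility assert above
interleaved = InterleavedPipelineSchedule(num_microbatches=8, num_model_chunks=2)
```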
@@ -408,6 +394,16 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """Forward step for passed-in model. If it is the first stage, the input tensor
         is obtained from data_iterator, otherwise the passed-in input_tensor is used.
         Returns output tensor. This is a helper function and can be ignored by users.

+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            model_chunk_id (int): The id of model chunks.
+            input_tensor (:class:`torch.Tensor`): Input tensor for this pipeline stage.
+            return_tensors (List[:class:`torch.Tensor`]): A list of tensors to return.
+            return_output_label (bool, optional): Whether to return output labels.
+            accum_loss (optional): Where the accumulated loss is stored.
+
+        Returns:
+            :class:`torch.Tensor`: Output or the loss value of the current pipeline stage.
         """
         data, label = self.load_micro_batch(model_chunk_id)
         output_tensor = self._call_engine(engine.model[model_chunk_id], input_tensor, data)
@@ -435,18 +431,17 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.

-        Returns dictionary with losses if the last stage, empty dict otherwise.
-
-        :param engine: Your engine object
-        :type engine: colossalai.engine.Engine
-        :param data_iter: Dataloader as the form of an iterator, obtained by calling iter(dataloader)
-        :type data_iter: Iterable
-        :param forward_only: Whether run forward step only. Default is false. If true, no backward will be run.
-        :type forward_only: bool
-        :param return_loss: Whether returns the loss value. Default is true.
-        :type return_loss: bool
-        :param return_output_label: If False, the output and label won't be returned
-        :type return_output_label: bool
+        Args:
+            engine (colossalai.engine.Engine): Colossalai engine for training and inference.
+            data_iter (Iterable): Dataloader as the form of an iterator, obtained by calling iter(dataloader).
+            forward_only (bool, optional):
+                Whether to run the forward step only. Default is false. If true, no backward will be run.
+            return_loss (bool, optional): Whether to return the loss value. Default is true.
+            return_output_label (bool, optional): If False, the output and label won't be returned.
+
+        Returns:
+            Tuple[:class:`torch.Tensor`]: A tuple of (output, label, loss); loss and label could be None.
+                The loss would be returned only in the last stage.
         """
         assert forward_only or return_loss, \
             'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
@@ -37,8 +37,8 @@ def get_default_parser():
     """Reads user command line and uses an argument parser to parse the input arguments.
     Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.

-    :return: Returns the parser with the default arguments, the user may add customized arguments into this parser
-    :rtype: Namespace
+    Returns:
+        Namespace: Parser with the default arguments; the user may add customized arguments into this parser.
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--config', type=str, help='path to the config file')
@@ -63,26 +63,21 @@ def launch(config: Union[str, Path, Config, Dict],
     """This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
     arguments are not given. Then initialize and set distributed environment by calling global_context's functions.

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param rank: Rank for the default process group
-    :type rank: int
-    :param world_size: World size of the default process group
-    :type world_size: int
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param local_rank: Rank for the process on the node and is used to set the default CUDA device, defaults to None.
-        If local_rank = None, the default device ordinal will be calculated automatically
-    :type local_rank: int, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
-    :raises Exception: Raise exception when config type is wrong
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable.
+        rank (int): Rank for the default process group.
+        world_size (int): World size of the default process group.
+        host (str): The master address for distributed training.
+        port (str): The master port for distributed training.
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
+        local_rank (int, optional):
+            Rank for the process on the node and is used to set the default CUDA device,
+            defaults to None. If local_rank = None, the default device ordinal will be calculated automatically.
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
+
+    Raises:
+        Exception: Raise exception when config type is wrong.
     """
     gpc.verbose = verbose
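A manual launch sketch for one node with two processes; all values are illustrative, and `port` is given as a string because that is how this docstring types it. In practice one of the `launch_from_*` wrappers documented below is usually preferred.

```python
import os
import colossalai

rank = int(os.environ.get('RANK', '0'))   # normally injected by the launcher
colossalai.launch(config='./config.py',   # illustrative config path
                  rank=rank,
                  world_size=2,
                  host='127.0.0.1',
                  port='29500',
                  backend='nccl')
```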
@@ -126,18 +121,13 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
     set by SLURM.

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable.
+        host (str): The master address for distributed training.
+        port (str): The master port for distributed training.
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['SLURM_PROCID'])
     world_size = int(os.environ['SLURM_NPROCS'])
@@ -160,18 +150,13 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
     set by OpenMPI.

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param host: The master address for distributed training
-    :type host: str
-    :param port: The master port for distributed training
-    :type port: str
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable.
+        host (str): The master address for distributed training.
+        port (str): The master port for distributed training.
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
     local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
@@ -194,14 +179,11 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
     """A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
     from the environment variables set by PyTorch.

-    :param config: Config file or config file path are both acceptable
-    :type config: Union[str, dict, Config]
-    :param backend: Backend for torch.distributed
-    :type backend: str, optional
-    :param seed: Specified random seed for every processes
-    :type seed: int, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        config (Union[str, dict, Config]): Config file or config file path are both acceptable.
+        backend (str, optional): Backend for ``torch.distributed``, defaults to ``nccl``.
+        seed (int, optional): Specified random seed for every process. Defaults to 1024.
+        verbose (bool, optional): Whether to print logs. Defaults to True.
     """
     rank = int(os.environ['RANK'])
     local_rank = int(os.environ['LOCAL_RANK'])
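End-to-end, the torch wrapper reduces a launch to one call; the script and command below are an illustrative sketch, not part of this commit.

```python
# train.py -- RANK, LOCAL_RANK and WORLD_SIZE are set by the launcher, e.g.:
#   torchrun --nproc_per_node=4 train.py
import colossalai

colossalai.launch_from_torch(config='./config.py', backend='nccl', seed=1024)
```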
@@ -230,22 +212,20 @@ def initialize(model: nn.Module,
     """Core function to wrap the essential training components with our functionality based on the config which is
     loaded into gpc.config.

-    :param model: Your model instance or a function to build the model
-    :type model: :class:`torch.nn.Module` or Callbale
-    :param optimizer: Your optimizer instance
-    :type optimizer: :class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`
-    :param criterion: Your criterion instance
-    :type criterion: :class:`torch.nn.modules.loss._Loss`, optional
-    :param train_dataloader: Dataloader for training
-    :type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param test_dataloader: Dataloader for testing
-    :type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
-    :param lr_scheduler: Your lr scheduler instance, optional
-    :type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
-    :param verbose: Whether to print logs
-    :type verbose: bool, optional
+    Args:
+        model (:class:`torch.nn.Module` or Callable): Your model instance or a function to build the model.
+        optimizer (:class:`torch.optim.optimizer.Optimizer` or :class:`Type[torch.optim.optimizer]`):
+            Your optimizer instance.
+        criterion (:class:`torch.nn.modules.loss._Loss`, optional): Your criterion instance.
+        train_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for training.
+        test_dataloader (:class:`torch.utils.data.DataLoader`, optional): Dataloader for testing.
+        lr_scheduler (:class:`torch.nn.lr_scheduler._LRScheduler`, optional): Your lr scheduler instance.
+        verbose (bool, optional): Whether to print logs.
+
+    Returns:
+        Tuple (engine, train_dataloader, test_dataloader, lr_scheduler):
+            A tuple of ``(engine, train_dataloader, test_dataloader, lr_scheduler)``
+            where only ``engine`` could not be None.
     """
     # get logger
     logger = get_dist_logger()
@@ -10,6 +10,8 @@ def get_dist_logger(name='colossalai'):
     """Get logger instance based on name. The DistributedLogger will create singleton instances,
     which means that only one logger instance is created per name.

+    Args:
     :param name: name of the logger, name must be unique
     :type name: str
@@ -23,8 +23,13 @@ except ImportError:
 class DistributedLogger:
     """This is a distributed event logger class essentially based on :class:`logging`.

-    :param name: The name of the logger
-    :type name: str
+    Args:
+        name (str): The name of the logger.
+
+    Note:
+        The parallel_mode used in ``info``, ``warning``, ``debug`` and ``error``
+        should be included in ``ParallelMode``. More details about ``ParallelMode`` can be found
+        in `parallel_mode <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/context/parallel_mode.py>`_.
     """

     __instances = dict()
@@ -33,10 +38,10 @@ class DistributedLogger:
     def get_instance(name: str):
         """Get the unique single logger instance based on name.

-        :param name: The name of the logger
-        :type name: str
-        :return: A DistributedLogger object
-        :rtype: DistributedLogger
+        Args:
+            name (str): The name of the logger.
+
+        Returns:
+            DistributedLogger: A DistributedLogger object.
         """
         if name in DistributedLogger.__instances:
             return DistributedLogger.__instances[name]
@@ -73,8 +78,8 @@ class DistributedLogger:
     def set_level(self, level: str):
         """Set the logging level.

-        :param level: Can only be INFO, DEBUG, WARNING and ERROR
-        :type level: str
+        Args:
+            level (str): Can only be INFO, DEBUG, WARNING and ERROR.
         """
         self._check_valid_logging_level(level)
         self._logger.setLevel(getattr(logging, level))
...@@ -82,14 +87,11 @@ class DistributedLogger:
    def log_to_file(self, path: Union[str, Path], mode: str = 'a', level: str = 'INFO', suffix: str = None):
        """Save the logs to file.

        Args:
            path (Union[str, pathlib.Path]): The file to save the log to.
            mode (str): The mode in which the log file is opened, defaults to 'a'.
            level (str): Can only be INFO, DEBUG, WARNING or ERROR, defaults to 'INFO'.
            suffix (str, optional): The suffix string appended to the log file's name.
        """
        assert isinstance(path, (str, Path)), \
            f'expected argument path to be type str or Path, but got {type(path)}'
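A hedged sketch of ``log_to_file``; how ``suffix`` is woven into the resulting file name is an implementation detail not shown in this hunk::

    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    logger.log_to_file('./train.log', mode='a', level='INFO', suffix='rank0')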
...@@ -131,12 +133,11 @@ class DistributedLogger:
    def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
        """Log an info message.

        Args:
            message (str): The message to be logged.
            parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
                The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
            ranks (list): List of parallel ranks.
        """
        message_prefix = "{}:{} {}".format(*self.__get_call_info())
        self._log('info', message_prefix, parallel_mode, ranks)
...@@ -145,12 +146,11 @@ class DistributedLogger:
    def warning(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
        """Log a warning message.

        Args:
            message (str): The message to be logged.
            parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
                The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
            ranks (list): List of parallel ranks.
        """
        message_prefix = "{}:{} {}".format(*self.__get_call_info())
        self._log('warning', message_prefix, parallel_mode, ranks)
...@@ -159,12 +159,11 @@ class DistributedLogger:
    def debug(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
        """Log a debug message.

        Args:
            message (str): The message to be logged.
            parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
                The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
            ranks (list): List of parallel ranks.
        """
        message_prefix = "{}:{} {}".format(*self.__get_call_info())
        self._log('debug', message_prefix, parallel_mode, ranks)
...@@ -173,12 +172,11 @@ class DistributedLogger:
    def error(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
        """Log an error message.

        Args:
            message (str): The message to be logged.
            parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`):
                The parallel mode used for logging. Defaults to ParallelMode.GLOBAL.
            ranks (list): List of parallel ranks.
        """
        message_prefix = "{}:{} {}".format(*self.__get_call_info())
        self._log('error', message_prefix, parallel_mode, ranks)
...
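Because every process would otherwise emit the same line, the ``ranks`` argument is typically used to restrict output; a sketch, assuming an initialized distributed context::

    from colossalai.context import ParallelMode
    from colossalai.logging import get_dist_logger

    logger = get_dist_logger()
    # log once from global rank 0 instead of once per process
    logger.info('epoch finished', parallel_mode=ParallelMode.GLOBAL, ranks=[0])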
...@@ -6,6 +6,7 @@ import torch.nn as nn
def zeros_():
    """Return the initializer filling the input Tensor with the scalar zeros."""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.zeros_(tensor)
...@@ -13,6 +14,7 @@ def zeros_():
def ones_():
    """Return the initializer filling the input Tensor with the scalar ones."""
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.ones_(tensor)
...@@ -20,6 +22,14 @@ def ones_():
def uniform_(a: float = 0., b: float = 1.):
    r"""Return the initializer filling the input Tensor with values drawn from the uniform
    distribution :math:`\mathcal{U}(a, b)`.

    Args:
        a (float): the lower bound of the uniform distribution. Defaults to 0.0.
        b (float): the upper bound of the uniform distribution. Defaults to 1.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.uniform_(tensor, a, b)
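Note the factory pattern used throughout this file: each function returns an ``initializer(tensor, fan_in, fan_out)`` closure instead of touching a tensor directly. A minimal sketch (the ``colossalai.nn.init`` import path is an assumption)::

    import torch
    from colossalai.nn import init

    weight = torch.empty(4, 8)
    init.uniform_(a=-0.1, b=0.1)(weight)  # build the initializer, then apply it in place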
...@@ -27,6 +37,15 @@ def uniform_(a: float = 0., b: float = 1.):
def normal_(mean: float = 0., std: float = 1.):
    r"""Return the initializer filling the input Tensor with values drawn from the normal distribution

    .. math::
        \mathcal{N}(\text{mean}, \text{std}^2)

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.normal_(tensor, mean, std)
...@@ -34,6 +53,19 @@ def normal_(mean: float = 0., std: float = 1.):
def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
    r"""Return the initializer filling the input Tensor with values drawn from a truncated
    normal distribution. The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        mean (float): the mean of the normal distribution. Defaults to 0.0.
        std (float): the standard deviation of the normal distribution. Defaults to 1.0.
        a (float): the minimum cutoff value. Defaults to -2.0.
        b (float): the maximum cutoff value. Defaults to 2.0.
    """
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        return nn.init.trunc_normal_(tensor, mean, std, a, b)
...@@ -41,6 +73,26 @@ def trunc_normal_(mean: float = 0., std: float = 1., a: float = -2., b: float = 2.):
def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as 'He initialization'.

    Args:
        a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
...@@ -64,6 +116,26 @@ def kaiming_uniform_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as 'He initialization'.

    Args:
        a (int): the negative slope of the rectifier used after this layer (only used with ``'leaky_relu'``).
        mode (str, optional): either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity (str, optional): the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        if 0 in tensor.shape:
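A rough numerical sanity check of the formula above; a sketch that assumes the closure honours the ``fan_in``/``fan_out`` arguments visible in its signature::

    import math
    import torch
    from colossalai.nn import init

    w = torch.empty(1024, 512)
    init.kaiming_normal_(mode='fan_in', nonlinearity='relu')(w, fan_in=512, fan_out=1024)
    # gain for relu is sqrt(2), so the sample std should be close to sqrt(2 / 512)
    print(w.std().item(), math.sqrt(2 / 512))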
...@@ -86,6 +158,23 @@ def kaiming_normal_(a=0, mode='fan_in', nonlinearity='leaky_relu'):
def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a uniform
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as 'Glorot initialization'.

    Args:
        a (float, optional): an optional scaling factor used to calculate uniform
            bounds from the standard deviation. Defaults to ``math.sqrt(3.)``.
        scale (float, optional): an optional scaling factor used to calculate the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
...@@ -102,6 +191,21 @@ def xavier_uniform_(a: float = math.sqrt(3.), scale: float = 2., gain: float = 1.):
def xavier_normal_(scale: float = 2., gain: float = 1.):
    r"""Return the initializer filling the input `Tensor` with values according to the method
    described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010), using a normal
    distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as 'Glorot initialization'.

    Args:
        scale (float, optional): an optional scaling factor used to calculate the standard deviation. Defaults to 2.0.
        gain (float, optional): an optional scaling factor. Defaults to 1.0.
    """
    # adapted from torch.nn.init
    def initializer(tensor: Tensor, fan_in: int = None, fan_out: int = None):
        assert fan_in is not None, 'Fan_in is not provided.'
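As the assertion shows, the xavier initializers require ``fan_in`` to be passed explicitly (and use ``fan_out`` when given)::

    import torch
    from colossalai.nn import init

    w = torch.empty(512, 256)
    init.xavier_uniform_()(w, fan_in=256, fan_out=512)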
...@@ -137,4 +241,4 @@ def lecun_normal_():
        std = math.sqrt(1.0 / fan_in)
        return nn.init.trunc_normal_(tensor, std=std / .87962566103423978)
    return initializer
\ No newline at end of file
...@@ -6,13 +6,11 @@ from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module):
    """Dropout layer of colossalai.

    Args:
        p (float, optional): probability of an element to be zeroed, defaults to 0.5.
        inplace (bool, optional): whether to do dropout in-place, defaults to False.
    """
    def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
        super().__init__()
...
...@@ -35,21 +35,33 @@ _parallel_patchembedding = {
class Embedding(nn.Module):
    r"""Embedding for colossalai.

    Args:
        num_embeddings (int): number of embeddings.
        embedding_dim (int): dimension of embedding.
        padding_idx (int, optional): If specified, the entries at padding_idx do not contribute to the gradient;
            therefore, the embedding vector at padding_idx is not updated during training,
            i.e. it remains as a fixed “pad”, defaults to None.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to normal initializer.

    The ``args`` and ``kwargs`` used in :func:`torch.nn.functional.embedding` should contain:
    ::

        max_norm (float, optional): If given, each embedding vector with norm larger than max_norm is
            renormalized to have norm max_norm. Note: this will modify weight in-place.
        norm_type (float, optional): The p of the p-norm to compute for the max_norm option. Defaults to 2.
        scale_grad_by_freq (bool, optional): If given, this will scale gradients by the inverse
            of frequency of the words in the mini-batch. Defaults to False.
        sparse (bool, optional): If True, the gradient w.r.t. weight will be a sparse tensor. Defaults to False.

    More details about ``args`` and ``kwargs`` could be found in
    `Embedding <https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html#torch.nn.functional.embedding>`_.

    For more details about ``initializer``, please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """
    def __init__(self,
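A hedged construction sketch (the ``colossalai.nn.Embedding`` import path is an assumption; the wrapper is expected to dispatch to the implementation matching the current tensor-parallel mode)::

    import torch
    from colossalai.nn import Embedding

    embed = Embedding(num_embeddings=30522, embedding_dim=768, padding_idx=0)
    tokens = torch.randint(0, 30522, (2, 128))
    out = embed(tokens)  # expected shape: (2, 128, 768)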
...@@ -97,27 +109,24 @@ class Embedding(nn.Module):
class PatchEmbedding(nn.Module):
    """2D Image to Patch Embedding.

    Args:
        img_size (int): image size.
        patch_size (int): patch size.
        in_chans (int): number of channels of the input image.
        embed_size (int): size of embedding.
        dtype (:class:`torch.dtype`, optional): The dtype of parameters, defaults to None.
        flatten (bool, optional): whether to flatten the output tensor, defaults to True.
        weight_initializer (:class:`typing.Callable`, optional):
            The initializer of weight, defaults to kaiming uniform initializer.
        bias_initializer (:class:`typing.Callable`, optional):
            The initializer of bias, defaults to xavier uniform initializer.
        position_embed_initializer (:class:`typing.Callable`, optional):
            The initializer of position embedding, defaults to zeros initializer.

    For more details about ``initializer``, please refer to
    `init <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/nn/init.py>`_.
    """
    def __init__(
...
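And a similar sketch for ``PatchEmbedding`` (assumed import path; with ``img_size=224`` and ``patch_size=16`` there are 196 patches, though the exact output shape depends on whether a class token is prepended)::

    import torch
    from colossalai.nn import PatchEmbedding

    patch_embed = PatchEmbedding(img_size=224, patch_size=16, in_chans=3, embed_size=768)
    images = torch.randn(2, 3, 224, 224)
    out = patch_embed(images)  # roughly (2, 196, 768), subject to the class-token caveat above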