Unverified Commit 0f8c7f98 authored by HELSON's avatar HELSON Committed by GitHub
Browse files

Fixed docstring in colossalai (#171)

parent e2089c5c
......@@ -10,7 +10,7 @@ from ...context.parallel_mode import ParallelMode
@GRADIENT_HANDLER.register_module
class MoeGradientHandler(BaseGradientHandler):
"""A helper class to handle all-reduce operations in a data parallel group and
moe tensor parallel. A all-reduce collective communication will be operated in
moe model parallel. A all-reduce collective communication will be operated in
:func:`handle_gradient` among a data parallel group.
For better performance, it bucketizes the gradients of all parameters that are
the same type to improve the efficiency of communication.
......@@ -19,7 +19,7 @@ class MoeGradientHandler(BaseGradientHandler):
def handle_gradient(self):
"""A method running an all-reduce operation in a data parallel group.
Then running an all-reduce operation for all parameters in experts
across moe tensor parallel group
across moe model parallel group
"""
moe_data = moe_env.data_parallel_size
global_data = gpc.data_parallel_size
......
......@@ -2,4 +2,4 @@ from ._base_schedule import BaseSchedule
from ._pipeline_schedule import PipelineSchedule, InterleavedPipelineSchedule
from ._non_pipeline_schedule import NonPipelineSchedule
__all__ = ['BaseSchedule', 'PipelineSchedule', 'NonPipelineSchedule', 'InterleavedPipelineSchedule']
__all__ = ['BaseSchedule', 'NonPipelineSchedule', 'PipelineSchedule', 'InterleavedPipelineSchedule']
......@@ -5,12 +5,13 @@ from abc import ABC, abstractmethod
import torch
from typing import Iterable, Callable
from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.nn.layer import split_batch
class BaseSchedule(ABC):
"""A basic helper class to control the process of training or evaluation.
It mainly composes of forward_backward_step for gradient backward and
......@@ -46,6 +47,11 @@ class BaseSchedule(ABC):
"""Loads a batch from data iterator. It returns the data and labels which are
already in the same GPU as where the model's.
:param data_iter: Data iterator from which get a batch of data
:type data_iter: DataIter
:param to_gpu: Whether the data should be moved to GPU
:type to_gpu: bool, optional
:return: (data, label)
:rtype: (:class:`Tensor`, :class:`torch.Tensor`)
"""
......@@ -62,13 +68,12 @@ class BaseSchedule(ABC):
if isinstance(data, torch.Tensor):
self.batch_size = data.size(0)
else:
self.batch_size = next(iter(data.values())).size(0)
self.batch_size = next(iter(data.values())).size(0)
data, label = split_batch(data), split_batch(label)
if to_gpu:
return self._move_to_device(data), self._move_to_device(label)
return data, label
def pre_processing(self, engine: Engine):
"""To perform actions before running the schedule.
"""
......@@ -85,11 +90,15 @@ class BaseSchedule(ABC):
"""The process function over a batch of dataset for training or evaluation.
:param engine: Colossalai training engine
:param inputs: input data
:param labels: ground truth
:type engine: colossalai.engine.Engine
:param data_iter: Data iterator from which get a batch of data
:type data_iter: DataIter
:param forward_only: If True, the process won't include backward
:type forward_only: bool
:param return_loss: If False, the loss won't be returned
:type return_loss: bool, optional
:param return_output_label: If False, the output and label won't be returned
:type return_output_label: bool, optional
"""
pass
......@@ -105,7 +114,7 @@ class BaseSchedule(ABC):
assert isinstance(outputs, (torch.Tensor, list, tuple)
), f'Expect output of model is (torch.Tensor, list, tuple), got {type(outputs)}'
if isinstance(outputs, torch.Tensor):
outputs = (outputs, )
outputs = (outputs,)
if isinstance(labels, torch.Tensor):
return engine.criterion(*outputs, labels)
else:
......
......@@ -15,10 +15,6 @@ class NonPipelineSchedule(BaseSchedule):
During one process, it loads a batch of dataset and feeds it to the model.
After getting the output and calculating the loss, it will use :meth:`step`
to update the parameters if it is in training mode.
:param amp_type: The type of automatic mixed precision
:param amp_config: The configuration of automatic mixed procision
:type amp_type: AMP_TYPE
:type amp_config: dict
"""
def forward_backward_step(self,
......@@ -29,6 +25,7 @@ class NonPipelineSchedule(BaseSchedule):
return_output_label: bool = True):
"""The process function that loads loads a batch of dataset and feeds it to the model.
The returned labels and loss will None if :attr:`return_loss` is False.
:param engine: Model for training and inference
:param data_iter: Data iterator of the dataloader, e.g. iter(dataloader)
:param forward_only: If True, the model is run for the forward pass, else back propagation will be executed
......
......@@ -44,9 +44,11 @@ class PipelineSchedule(BaseSchedule):
:param num_microbatches: The number of microbatches
:type num_microbatches: int
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
:type batch_data_process_func: Callable
:type batch_data_process_func: Callable, optional
:param tensor_shape: Specified shape in pipeline communication
:type tensor_shape: torch.Size, optional
:param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
:type scatter_gather_tensors: bool
:type scatter_gather_tensors: bool, optional
"""
def __init__(self,
......@@ -130,12 +132,16 @@ class PipelineSchedule(BaseSchedule):
is obtained from data_iterator, otherwise the passed-in input_tensor is used.
Returns output tensor. This is a helper function and can be ignored by users.
:param engine: your engine object
:param engine: Your engine object
:type engine: colossalai.engine.Engine
:param input_tensor: input tensor for this pipeline stage
:param input_tensor: Input tensor for this pipeline stage
:type input_tensor: :class:`torch.Tensor`
:param return_tensors: a list of tensors to return
:param return_tensors: A list of tensors to return
:type return_tensors: List[:class:`torch.Tensor`]
:param return_output_label: Whether returns output labels
:type return_output_label: bool, optional
:param accum_loss: Where accumulated loss stores
:type accum_loss: optional
:return: output or the loss value of the current pipeline stage
:rtype: :class:`torch.Tensor`
......@@ -205,13 +211,13 @@ class PipelineSchedule(BaseSchedule):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
:param engine: your engine object
:param engine: Your engine object
:type engine: colossalai.engine.Engine
:param data_iter: dataloader as the form of an iterator, obtained by calling iter(dataloader)
:param data_iter: Dataloader as the form of an iterator, obtained by calling iter(dataloader)
:type data_iter: Iterable
:param forward_only: whether run forward step only. Default is false. If true, no backward will be run.
:param forward_only: Whether run forward step only. Default is false. If true, no backward will be run.
:type forward_only: bool
:param return_loss: whether returns the loss value. Default is true.
:param return_loss: Whether returns the loss value. Default is true.
:type return_loss: bool
:param return_output_label: If False, the output and label won't be returned
:type return_output_label: bool
......@@ -357,9 +363,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
:param num_model_chunks: The number of model chunks
:type num_model_chunks: int
:param batch_data_process_func: The preprocessing function which receives a batch of data, and it will be executed in `load_batch`
:type batch_data_process_func: Callable
:type batch_data_process_func: Callable, optional
:param tensor_shape: Specified shape in pipeline communication
:type tensor_shape: torch.Size, optional
:param scatter_gather_tensors: If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization
:type scatter_gather_tensors: bool
:type scatter_gather_tensors: bool, optional
"""
assert num_microbatches % gpc.get_world_size(ParallelMode.PIPELINE) == 0, \
'num_microbatches must be an integer multiple of pipeline parallel world size'
......@@ -425,7 +433,19 @@ class InterleavedPipelineSchedule(PipelineSchedule):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.
Returns dictionary with losses if the last stage, empty dict otherwise."""
Returns dictionary with losses if the last stage, empty dict otherwise.
:param engine: Your engine object
:type engine: colossalai.engine.Engine
:param data_iter: Dataloader as the form of an iterator, obtained by calling iter(dataloader)
:type data_iter: Iterable
:param forward_only: Whether run forward step only. Default is false. If true, no backward will be run.
:type forward_only: bool
:param return_loss: Whether returns the loss value. Default is true.
:type return_loss: bool
:param return_output_label: If False, the output and label won't be returned
:type return_output_label: bool
"""
assert forward_only or return_loss, \
'The argument \'return_loss\' has to be True when \'forward_only\' is False, but got False.'
self.load_batch(data_iter)
......
class MoeEnv:
"""Moe enviroment variable.
"""Moe enviroment variables.
"""
def __init__(self):
......
......@@ -29,12 +29,12 @@ from colossalai.global_variables import moe_env
def get_default_parser():
'''Reads user command line and uses an argument parser to parse the input arguments.
"""Reads user command line and uses an argument parser to parse the input arguments.
Input arguments include configuration, host, port, world size, local rank, backend for torch.distributed.
:return: returns the parser with the default arguments, the user may add customized arguments into this parser
:return: Returns the parser with the default arguments, the user may add customized arguments into this parser
:rtype: Namespace
'''
"""
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str, help='path to the config file')
parser.add_argument('--host',
......@@ -64,28 +64,30 @@ def launch(config: Union[str, Path, Config, Dict],
local_rank: int = None,
seed: int = 1024,
verbose: bool = True):
'''This function first parses the configuration arguments, using :func:parse_args() in case one of the input arguments are not given.
Then initialize and set distributed environment by calling global_context's functions.
"""This function first parses the configuration arguments, using :func:`parse_args()` in case one of the input
arguments are not given. Then initialize and set distributed environment by calling global_context's functions.
:param config: config file or config file path are both acceptable
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param rank: rank for the default process group
:param rank: Rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:param world_size: World size of the default process group
:type world_size: int
:param host: the master address for distributed training
:param host: The master address for distributed training
:type host: str
:param port: the master port for distributed training
:param port: The master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param local_rank: rank for the process on the node and is used to set the default CUDA device,
defaults to None. If local_rank = None, the default device ordinal will be calculated automatically
:param backend: Backend for torch.distributed
:type backend: str, optional
:param local_rank: Rank for the process on the node and is used to set the default CUDA device, defaults to None.
If local_rank = None, the default device ordinal will be calculated automatically
:type local_rank: int, optional
:param verbose: whether to print logs
:type verbose: bool
:raises Exception: raise exception when config type is wrong
'''
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
:raises Exception: Raise exception when config type is wrong
"""
gpc.verbose = verbose
# set config
......@@ -123,20 +125,22 @@ def launch_from_slurm(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
'''A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
"""A wrapper for colossalai.launch for SLURM launcher by reading rank and world size from the environment variables
set by SLURM
:param config: config file or config file path are both acceptable
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param host: the master address for distributed training
:param host: The master address for distributed training
:type host: str
:param port: the master port for distributed training
:param port: The master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param verbose: whether to print logs
:type verbose: bool
'''
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
"""
rank = int(os.environ['SLURM_PROCID'])
world_size = int(os.environ['SLURM_NPROCS'])
launch(config=config,
......@@ -155,20 +159,22 @@ def launch_from_openmpi(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
'''A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
"""A wrapper for colossalai.launch for OpenMPI launcher by reading rank and world size from the environment variables
set by OpenMPI
:param config: config file or config file path are both acceptable
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param host: the master address for distributed training
:param host: The master address for distributed training
:type host: str
:param port: the master port for distributed training
:param port: The master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param verbose: whether to print logs
:type verbose: bool
'''
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
"""
rank = int(os.environ['OMPI_COMM_WORLD_RANK'])
local_rank = int(os.environ['OMPI_COMM_WORLD_LOCAL_RANK'])
world_size = int(os.environ['OMPI_COMM_WORLD_SIZE'])
......@@ -187,20 +193,18 @@ def launch_from_torch(config: Union[str, Path, Config, Dict],
backend: str = 'nccl',
seed: int = 1024,
verbose: bool = True):
'''A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
"""A wrapper for colossalai.launch for torchrun or torch.distributed.launch by reading rank and world size
from the environment variables set by PyTorch
:param config: config file or config file path are both acceptable
:param config: Config file or config file path are both acceptable
:type config: Union[str, dict, Config]
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
:param verbose: whether to print logs
:type verbose: bool
'''
:param backend: Backend for torch.distributed
:type backend: str, optional
:param seed: Specified random seed for every processes
:type seed: int, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
"""
rank = int(os.environ['RANK'])
local_rank = int(os.environ['LOCAL_RANK'])
world_size = int(os.environ['WORLD_SIZE'])
......@@ -225,25 +229,26 @@ def initialize(model: Union[nn.Module, List[nn.Module]],
lr_scheduler: _LRScheduler = None,
verbose: bool = True
) -> Tuple[Engine, DataLoader, DataLoader, _LRScheduler]:
''' Core function to wrap the essential training components with our functionality based on the config which is loaded into gpc.config.
"""Core function to wrap the essential training components with our functionality based on the config which is
loaded into gpc.config.
:param model: your model instance
:param model: Your model instance
:type model: :class:`torch.nn.Module`
:param optimizer: your optimizer instance
:param optimizer: Your optimizer instance
:type optimizer: :class:`torch.optim.optimizer.Optimizer`
:param criterion: your criterion instance
:param criterion: Your criterion instance
:type criterion: :class:`torch.nn.modules.loss._Loss`
:param train_dataloader: dataloader for training data
:type train_dataloader: :class:`torch.utils.data.DataLoader`
:param train_dataloader: dataloader for testing data
:type train_dataloader: :class:`torch.utils.data.DataLoader`
:param lr_scheduler: your lr scheduler instance
:type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`
:param verbose: whether to print logs
:type verbose: bool
:param train_dataloader: Dataloader for training
:type train_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param test_dataloader: Dataloader for testing
:type test_dataloader: :class:`torch.utils.data.DataLoader`, optional
:param lr_scheduler: Your lr scheduler instance
:type lr_scheduler: :class:`torch.nn.lr_scheduler._LRScheduler`, optional
:param verbose: Whether to print logs
:type verbose: bool, optional
:return: (engine, train_dataloader, test_dataloader, lr_scheduler)
:rtype: tuple
'''
:rtype: Tuple
"""
# get logger
logger = get_dist_logger()
gpc.verbose = verbose
......
......@@ -106,8 +106,10 @@ class MultiHeadAttention(nn.Module):
"""Initialize the MultiHeadAttention.
Static variable:
layer_id: The layer-index counter starting from 0 and incrementing by 1 every time a layer object is instantiated,
e.g. if a model has 24 transformer layers, layer_id goes from 0 to 23.
Arguments:
hidden_size: Total dimension of hidden_size.
nhead: Number of parallel attention heads.
......
......@@ -14,9 +14,10 @@ class AttnMaskType(enum.Enum):
class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
1. Scale the tensor.
2. Apply upper triangular mask (typically used in gpt models).
3. Perform softmax.
"""
@staticmethod
......@@ -52,9 +53,10 @@ class ScaledUpperTriangMaskedSoftmax(torch.autograd.Function):
class ScaledMaskedSoftmax(torch.autograd.Function):
"""
Fused operation which performs following three operations in sequence
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
1. Scale the tensor.
2. Apply the mask.
3. Perform softmax.
"""
@staticmethod
......@@ -87,16 +89,16 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
class FusedScaleMaskSoftmax(nn.Module):
"""
fused operation: scaling + mask + softmax
Fused operation: scaling + mask + softmax
Arguments:
input_in_fp16: flag to indicate if input in fp16 data format.
input_in_bf16: flag to indicate if input in bf16 data format.
attn_mask_type: attention mask type (pad or causal)
scaled_masked_softmax_fusion: flag to indicate user want to use softmax fusion
mask_func: mask function to be applied.
softmax_in_fp32: if true, softmax in performed at fp32 precision.
scale: scaling factor used in input tensor scaling.
input_in_fp16: Flag to indicate if input in fp16 data format.
input_in_bf16: Flag to indicate if input in bf16 data format.
attn_mask_type: Attention mask type (pad or causal)
scaled_masked_softmax_fusion: Flag to indicate user want to use softmax fusion
mask_func: Mask function to be applied.
softmax_in_fp32: If True, softmax in performed at fp32 precision.
scale: Scaling factor used in input tensor scaling.
"""
def __init__(
......
......@@ -25,9 +25,10 @@ class DistributedLogger:
@staticmethod
def get_instance(name: str):
"""Get the unique single logger instance based on name.
:param name: The name of the logger
:type name: str
:return: a DistributedLogger object
:return: A DistributedLogger object
:rtype: DistributedLogger
"""
if name in DistributedLogger.__instances:
......@@ -50,7 +51,8 @@ class DistributedLogger:
def set_level(self, level: str):
"""Set the logging level
:param level: can only be INFO, DEBUG, WARNING and ERROR
:param level: Can only be INFO, DEBUG, WARNING and ERROR
:type level: str
"""
self._check_valid_logging_level(level)
......@@ -62,12 +64,15 @@ class DistributedLogger:
level: str = 'INFO',
suffix: str = None):
"""Save the logs to file
:param path: the file to save the log
:type path: a string or pathlib.Path object
:param mode: the mode to write log into the file
:param path: The file to save the log
:type path: A string or pathlib.Path object
:param mode: The mode to write log into the file
:type mode: str
:param level: can only be INFO, DEBUG, WARNING and ERROR
:param level: Can only be INFO, DEBUG, WARNING and ERROR
:type level: str
:param suffix: The suffix string of log's name
:type suffix: str
"""
assert isinstance(path, (str, Path)), \
f'expected argument path to be type str or Path, but got {type(path)}'
......@@ -105,12 +110,12 @@ class DistributedLogger:
def info(self, message: str, parallel_mode: ParallelMode = ParallelMode.GLOBAL, ranks: list = None):
"""Log an info message.
:param message:
:type message:
:param parallel_mode:
:type parallel_mode:
:param ranks:
:type ranks:
:param message: The message to be logged
:type message: str
:param parallel_mode: The parallel mode used for logging. Defaults to ParallelMode.GLOBAL
:type parallel_mode: :class:`colossalai.context.parallel_mode.ParallelMode`
:param ranks: List of parallel ranks
:type ranks: list
"""
self._log('info', message, parallel_mode, ranks)
......
......@@ -37,6 +37,8 @@ class Embedding(nn.Module):
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
......
......@@ -40,6 +40,7 @@ class Linear(nn.Module):
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param kwargs: Kwargs used for initialization
"""
def __init__(self,
in_features: int,
......
......@@ -12,8 +12,12 @@ from ._operation import AllToAll
class NormalNoiseGenerator:
"""Generates a random noisy mask for logtis tensor.
All noise is generated from a normal distribution (0, 1 / E^2), where
E = the number of experts.
:param num_experts: The number of experts
:type num_experts: int
"""
def __init__(self, num_experts: int):
......@@ -31,6 +35,12 @@ class Experts(nn.Module):
"""A wrapper class to create experts. It will create E experts across the
moe model parallel group, where E is the number of experts. Every expert
is a instence of the class, 'expert' in initialization parameters.
:param expert: The class of all experts
:param num_experts: The number of experts
:param expert_args: Args used to initialize experts
:type num_experts: int
"""
def __init__(self, expert, num_experts, **expert_args):
......@@ -63,6 +73,14 @@ class Top1Router(nn.Module):
"""Top1 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about Switch Transformer
of Google.
:param capacity_factor: Capacity factor in routing
:param min_capacity: The minimum number of the capacity of each expert
:param noisy_func: Noisy function used in logits
:type capacity_factor: float
:type min_capacity: int
:type noisy_func: Callable, optional
"""
def __init__(self,
......@@ -127,6 +145,12 @@ class Top1Router(nn.Module):
class Top2Router(nn.Module):
"""Top2 router that returns the dispatch mask [s, e, c] and combine weight [s, e, c]
for routing usage. More deailted function can be found in the paper about ViT-MoE.
:param capacity_factor: Capacity factor in routing
:param noisy_func: Noisy function used in logits
:type capacity_factor: float
:type noisy_func: Callable, optional
"""
def __init__(self, capacity_factor: float, noisy_func=None):
......@@ -189,6 +213,16 @@ class MoeLayer(nn.Module):
to router all tokens, is mainly used to exchange all tokens for every expert across
the moe tensor group by all to all comunication. Then it will get the output of all
experts and exchange the output. At last returns the output of the moe system.
:param dim_model: Dimension of model
:param num_experts: The number of experts
:param router: Instance of router used in routing
:param experts: Instance of experts generated by Expert
:type dim_model: int
:type num_experts: int
:type router: nn.Module
:type experts: nn.Module
"""
def __init__(self,
......
......@@ -186,6 +186,12 @@ class Linear1D_Col(ParallelLayer):
to all GPUs, otherwise, every GPU will have its output
which is :math:`Y_i = XA_i`, defaults to False
:type gather_output: bool, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
......@@ -268,6 +274,12 @@ class Linear1D_Row(ParallelLayer):
:type dtype: torch.dtype, optional
:param parallel_input: If set to ``True``, it's assumed that the input is splitted, defaults to False
:type parallel_input: bool, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
......@@ -383,6 +395,8 @@ class Embedding1D(ParallelLayer):
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
......
......@@ -771,6 +771,17 @@ class SplitFirst(torch.autograd.Function):
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
"""Splits 2D tensor in specified dimension across cols
:param input_: Input tensor
:param dim: Specified dimension in which to split
:type input_: torch.Tensor
:type dim: int, optional
:return output: Splitted tensor
:rtype output: torch.Tensor
"""
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2D_COL),
......@@ -778,13 +789,7 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2d(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""All-reduce the input from the model parallel region.
"""
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
......@@ -797,6 +802,12 @@ class reduce_by_batch_2d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_, reduce_mean: bool = False):
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
......
......@@ -303,6 +303,8 @@ class Embedding2D(ParallelLayer):
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
......
......@@ -733,6 +733,17 @@ class SplitFirst(torch.autograd.Function):
def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
"""Splits 2P5D tensor in specified dimension across cols
:param input_: Input tensor
:param dim: Specified dimension in which to split
:type input_: torch.Tensor
:type dim: int, optional
:return output: Splitted tensor
:rtype output: torch.Tensor
"""
if input_.size(dim) <= 1:
return input_
return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
......@@ -740,13 +751,7 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2p5d(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""All-reduce the input from the model parallel region.
"""
@staticmethod
def symbolic(graph, input_, reduce_mean: bool = False):
......@@ -759,6 +764,12 @@ class reduce_by_batch_2p5d(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_, reduce_mean: bool = False):
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_mean = reduce_mean
if reduce_mean:
......
......@@ -315,6 +315,8 @@ class Embedding2p5D(ParallelLayer):
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
......
......@@ -240,6 +240,21 @@ def split_tensor_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
"""Splits 3D tensor in specified dimension
:param input_: Input tensor
:param dim: Specified dimension in which to split
:param input_parallel_mode: Input parallel mode
:param weight_parallel_mode: Weight parallel mode
:type input_: torch.Tensor
:type dim: int, optional
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:return output: Splitted tensor
:rtype output: torch.Tensor
"""
if input_.size(dim) <= 1:
return input_
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
......@@ -250,17 +265,7 @@ def split_tensor_3d(input_: Tensor,
class reduce_by_batch_3d(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:type reduce_mean: int, optional
"""All-reduce the input from the model parallel region.
"""
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
......@@ -269,6 +274,16 @@ class reduce_by_batch_3d(torch.autograd.Function):
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean
......
......@@ -338,6 +338,8 @@ class Embedding3D(ParallelLayer):
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment