Unverified Commit 4a3d3446 authored by BoxiangW's avatar BoxiangW Committed by GitHub
Browse files

Update layer integration documentations (#108)

Update the documentations of layer integration

Update _log_hook.py

Update _operation.py
parent 3a61d785
...@@ -9,6 +9,14 @@ from ..utils import get_tensor_parallel_mode ...@@ -9,6 +9,14 @@ from ..utils import get_tensor_parallel_mode
class Dropout(nn.Module): class Dropout(nn.Module):
"""
Dropout layer of colossalai
:param p: dropout rate, defaults to 0.5
:type p: float, optional
:param inplace: If set to ``True``, will do this operation in-place, defaults tp ``False``
:type inplace: bool, optional
"""
def __init__(self, p: float = 0.5, inplace: bool = False) -> None: def __init__(self, p: float = 0.5, inplace: bool = False) -> None:
super().__init__() super().__init__()
self.tensor_parallel = get_tensor_parallel_mode() self.tensor_parallel = get_tensor_parallel_mode()
......
...@@ -24,6 +24,20 @@ _parallel_patchembedding = { ...@@ -24,6 +24,20 @@ _parallel_patchembedding = {
class Embedding(nn.Module): class Embedding(nn.Module):
"""
Embedding for colossalai
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
...@@ -63,6 +77,28 @@ class Embedding(nn.Module): ...@@ -63,6 +77,28 @@ class Embedding(nn.Module):
class PatchEmbedding(nn.Module): class PatchEmbedding(nn.Module):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
......
...@@ -25,6 +25,22 @@ _parallel_classifier = { ...@@ -25,6 +25,22 @@ _parallel_classifier = {
class Linear(nn.Module): class Linear(nn.Module):
"""
Linear layer of colossalai
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
...@@ -64,6 +80,22 @@ class Linear(nn.Module): ...@@ -64,6 +80,22 @@ class Linear(nn.Module):
class Classifier(nn.Module): class Classifier(nn.Module):
"""
Classifier layer of colossalai
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of total classes for the dataset
:type num_classes: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__( def __init__(
self, self,
in_features: int, in_features: int,
......
...@@ -15,6 +15,19 @@ _parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm ...@@ -15,6 +15,19 @@ _parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm
class LayerNorm(nn.Module): class LayerNorm(nn.Module):
r"""
Layer Normalization for colossalai
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None: def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
......
...@@ -7,6 +7,18 @@ except: ...@@ -7,6 +7,18 @@ except:
class FusedLayerNormAffineFunction1D(torch.autograd.Function): class FusedLayerNormAffineFunction1D(torch.autograd.Function):
r"""
Layernorm
:param input: input maxtrix
:param weight: weight matrix
:param bias: bias matrix
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:param eps: a value added to the denominator for numerical stability
"""
@staticmethod @staticmethod
def forward(ctx, input, weight, bias, normalized_shape, eps): def forward(ctx, input, weight, bias, normalized_shape, eps):
......
...@@ -76,7 +76,12 @@ def _gather(input_, parallel_mode, dim=-1): ...@@ -76,7 +76,12 @@ def _gather(input_, parallel_mode, dim=-1):
class _ReduceGrad(torch.autograd.Function): class _ReduceGrad(torch.autograd.Function):
"""Pass the input to the model parallel region.""" """
Pass the input to the model parallel region.
:param input_: input matrix
:param parallel_mode: parallel mode
"""
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return input_ return input_
...@@ -92,7 +97,12 @@ class _ReduceGrad(torch.autograd.Function): ...@@ -92,7 +97,12 @@ class _ReduceGrad(torch.autograd.Function):
class _ReduceInput(torch.autograd.Function): class _ReduceInput(torch.autograd.Function):
"""All-reduce the input from the model parallel region.""" """
All-reduce the input from the model parallel region.
:param input_: input matrix
:param parallel_mode: parallel mode
"""
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _reduce(input_) return _reduce(input_)
...@@ -107,7 +117,13 @@ class _ReduceInput(torch.autograd.Function): ...@@ -107,7 +117,13 @@ class _ReduceInput(torch.autograd.Function):
class _SplitForwardGatherBackward(torch.autograd.Function): class _SplitForwardGatherBackward(torch.autograd.Function):
"""Split the input and keep only the corresponding chuck to the rank.""" """
Split the input and keep only the corresponding chuck to the rank.
:param input_: input matrix
:param parallel_mode: parallel mode
:param dim: dimension
"""
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _split(input_) return _split(input_)
...@@ -124,7 +140,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function): ...@@ -124,7 +140,13 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
class _GatherForwardSplitBackward(torch.autograd.Function): class _GatherForwardSplitBackward(torch.autograd.Function):
"""Gather the input from model parallel region and concatinate.""" """
Gather the input from model parallel region and concatinate.
:param input_: input matrix
:param parallel_mode: parallel mode
:param dim: dimension
"""
@staticmethod @staticmethod
def symbolic(graph, input_): def symbolic(graph, input_):
return _gather(input_) return _gather(input_)
......
...@@ -26,6 +26,24 @@ from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_g ...@@ -26,6 +26,24 @@ from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_g
@LAYERS.register_module @LAYERS.register_module
class Linear1D(torch.nn.Module): class Linear1D(torch.nn.Module):
"""
Linear layer for 1D parallelism
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
...@@ -70,8 +88,24 @@ class Linear1D(torch.nn.Module): ...@@ -70,8 +88,24 @@ class Linear1D(torch.nn.Module):
@LAYERS.register_module @LAYERS.register_module
class Classifier1D(ParallelLayer): class Classifier1D(ParallelLayer):
"""RowLinear with given weight""" """RowLinear with given weight
Classifier of 1D parallelism
:param in_features: size of input features
:type in_features: int
:param num_classes: number of classes in the dataset
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
...@@ -144,7 +178,7 @@ class Linear1D_Col(ParallelLayer): ...@@ -144,7 +178,7 @@ class Linear1D_Col(ParallelLayer):
:type in_features: int :type in_features: int
:param output_size: second dimension of matrix A. :param output_size: second dimension of matrix A.
:type output_size: int :type output_size: int
:param bias: If true, add bias, defaults to True :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional :type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
...@@ -228,7 +262,7 @@ class Linear1D_Row(ParallelLayer): ...@@ -228,7 +262,7 @@ class Linear1D_Row(ParallelLayer):
:type in_features: int :type in_features: int
:param out_features: size of each output sample :param out_features: size of each output sample
:type out_features: int :type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True :param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional :type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
...@@ -303,7 +337,16 @@ class Linear1D_Row(ParallelLayer): ...@@ -303,7 +337,16 @@ class Linear1D_Row(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module): class MixedFusedLayerNorm1D(torch.nn.Module):
""" Experimental r"""
Layer Normalization for 1D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
""" """
def __init__(self, normalized_shape, eps=1e-5): def __init__(self, normalized_shape, eps=1e-5):
...@@ -327,6 +370,20 @@ class MixedFusedLayerNorm1D(torch.nn.Module): ...@@ -327,6 +370,20 @@ class MixedFusedLayerNorm1D(torch.nn.Module):
@LAYERS.register_module @LAYERS.register_module
class Embedding1D(ParallelLayer): class Embedding1D(ParallelLayer):
"""
Embedding for 1D parallelism
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
...@@ -377,6 +434,14 @@ class Embedding1D(ParallelLayer): ...@@ -377,6 +434,14 @@ class Embedding1D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Dropout1D(ParallelLayer): class Dropout1D(ParallelLayer):
"""
Dropout layer of 1D parallelism
:param p: dropout rate, defaults to 0.5
:type p: float, optional
:param inplace: If set to ``True``, will do this operation in-place, defaults tp ``False``
:type inplace: bool, optional
"""
def __init__(self, p: float = 0.5, inplace: bool = False): def __init__(self, p: float = 0.5, inplace: bool = False):
super().__init__() super().__init__()
self.parallel_input = get_parallel_input() self.parallel_input = get_parallel_input()
......
...@@ -20,7 +20,8 @@ def matmul_2d( ...@@ -20,7 +20,8 @@ def matmul_2d(
row_parallel_mode=ParallelMode.PARALLEL_2D_ROW, row_parallel_mode=ParallelMode.PARALLEL_2D_ROW,
col_parallel_mode=ParallelMode.PARALLEL_2D_COL, col_parallel_mode=ParallelMode.PARALLEL_2D_COL,
): ):
"""Matrix multiplication for 2D parallelism """
Matrix multiplication for 2D parallelism
:param a: matrix :math:`A` :param a: matrix :math:`A`
:type a: torch.tensor :type a: torch.tensor
:param b: matrix :math:`B` :param b: matrix :math:`B`
...@@ -56,7 +57,35 @@ def matmul_2d( ...@@ -56,7 +57,35 @@ def matmul_2d(
class classifier_2d(torch.autograd.Function): class classifier_2d(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -130,7 +159,33 @@ class classifier_2d(torch.autograd.Function): ...@@ -130,7 +159,33 @@ class classifier_2d(torch.autograd.Function):
class Matmul_AB_2D(torch.autograd.Function): class Matmul_AB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """
Matrix multiplication for :math:`C = AB`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -238,7 +293,33 @@ class Matmul_AB_2D(torch.autograd.Function): ...@@ -238,7 +293,33 @@ class Matmul_AB_2D(torch.autograd.Function):
class Matmul_ABT_2D(torch.autograd.Function): class Matmul_ABT_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T` """
Matrix multiplication for :math:`C = AB^T`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -352,7 +433,33 @@ class Matmul_ABT_2D(torch.autograd.Function): ...@@ -352,7 +433,33 @@ class Matmul_ABT_2D(torch.autograd.Function):
class Matmul_ATB_2D(torch.autograd.Function): class Matmul_ATB_2D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB` """
Matrix multiplication for :math:`C = A^TB`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -466,7 +573,33 @@ class Matmul_ATB_2D(torch.autograd.Function): ...@@ -466,7 +573,33 @@ class Matmul_ATB_2D(torch.autograd.Function):
class add_bias_2d(torch.autograd.Function): class add_bias_2d(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b` """
Matrix add bias: :math:`C = A + b`
:param input_: matrix :math:`A`
:type input_: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: size of ouput per partition
:type output_size_per_partition: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -519,9 +652,30 @@ class add_bias_2d(torch.autograd.Function): ...@@ -519,9 +652,30 @@ class add_bias_2d(torch.autograd.Function):
class layernorm_2d(torch.autograd.Function): class layernorm_2d(torch.autograd.Function):
"""
Layernorm
:param input_: input maxtrix
:type input_: torch.tensor
:param E_x: mean
:type E_x: torch.tensor
:param Var_x: variance
:type Var_x: torch.tensor
:param hidden_size: hidden size
:type hidden_size: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode, def forward(ctx: Any,
input_: Tensor,
E_x: Tensor,
Var_x: Tensor,
hidden_size: int,
row_parallel_mode: ParallelMode,
col_parallel_mode: ParallelMode) -> Tensor: col_parallel_mode: ParallelMode) -> Tensor:
input_ = input_ - E_x input_ = input_ - E_x
# in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps) # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
...@@ -556,6 +710,18 @@ class layernorm_2d(torch.autograd.Function): ...@@ -556,6 +710,18 @@ class layernorm_2d(torch.autograd.Function):
class all_gather_weight_2d(torch.autograd.Function): class all_gather_weight_2d(torch.autograd.Function):
"""
all gather the weight of 2D parallelism
:param inputs: input maxtrix
:type inputs: torch.tensor
:param dim: dimension of all gather
:type dim: int
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
...@@ -574,6 +740,14 @@ class all_gather_weight_2d(torch.autograd.Function): ...@@ -574,6 +740,14 @@ class all_gather_weight_2d(torch.autograd.Function):
class SplitFirst(torch.autograd.Function): class SplitFirst(torch.autograd.Function):
"""
:param inputs: input maxtrix
:type inputs: torch.tensor
:param summa_dim: dimension of SUMMA fo 2D parallelism
:type summa_dim: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
...@@ -604,7 +778,14 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -604,7 +778,14 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2d(torch.autograd.Function): class reduce_by_batch_2d(torch.autograd.Function):
"""All-reduce the input from the model parallel region.""" """
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
@staticmethod @staticmethod
def symbolic(graph, input_, reduce_mean: bool = False): def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
......
...@@ -21,7 +21,8 @@ from ._utils import assert_summa_initialization, get_summa_dim_from_env ...@@ -21,7 +21,8 @@ from ._utils import assert_summa_initialization, get_summa_dim_from_env
@LAYERS.register_module @LAYERS.register_module
class Linear2D(ParallelLayer): class Linear2D(ParallelLayer):
""" Linear layer for 2D parallelism """
Linear layer for 2D parallelism
:param in_features: size of each input sample :param in_features: size of each input sample
:type in_features: int :type in_features: int
...@@ -33,6 +34,10 @@ class Linear2D(ParallelLayer): ...@@ -33,6 +34,10 @@ class Linear2D(ParallelLayer):
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion, defaults to False
:type skip_bias_add: bool, optional :type skip_bias_add: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
...@@ -113,7 +118,8 @@ class Linear2D(ParallelLayer): ...@@ -113,7 +118,8 @@ class Linear2D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class LayerNorm2D(ParallelLayer): class LayerNorm2D(ParallelLayer):
r"""Layer Normalization for 2D parallelism r"""
Layer Normalization for 2D parallelism
:param normalized_shape: input shape from an expected input :param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
...@@ -184,18 +190,27 @@ class LayerNorm2D(ParallelLayer): ...@@ -184,18 +190,27 @@ class LayerNorm2D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class PatchEmbedding2D(ParallelLayer): class PatchEmbedding2D(ParallelLayer):
""" 2D Image to Patch Embedding """
2D Image to Patch Embedding
:param img_size: iamge size :param img_size: image size
:type img_size: int :type img_size: int
:param patch_size: patch size :param patch_size: patch size
:type patch_size: int :type patch_size: int
:param embed_dim: dimension of embedding :param in_chans: number of channels of input image
:type embed_dim: int :type in_chans: int
:param in_chans: number of channels of input image, defaults to 3 :param embed_size: size of embedding
:type in_chans: int, optional :type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True :param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional :type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
...@@ -275,6 +290,20 @@ class PatchEmbedding2D(ParallelLayer): ...@@ -275,6 +290,20 @@ class PatchEmbedding2D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Embedding2D(ParallelLayer): class Embedding2D(ParallelLayer):
"""
Embedding for 2D parallelism
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
...@@ -325,6 +354,24 @@ class Embedding2D(ParallelLayer): ...@@ -325,6 +354,24 @@ class Embedding2D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Classifier2D(ParallelLayer): class Classifier2D(ParallelLayer):
"""
Classifier for 2D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
......
...@@ -28,7 +28,35 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -28,7 +28,35 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class classifier_2p5d(torch.autograd.Function): class classifier_2p5d(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -101,7 +129,35 @@ class classifier_2p5d(torch.autograd.Function): ...@@ -101,7 +129,35 @@ class classifier_2p5d(torch.autograd.Function):
class Matmul_AB_2p5D(torch.autograd.Function): class Matmul_AB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB` """
Matrix multiplication for :math:`C = AB`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param dep_rank: the rank of depth
:type dep_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -202,7 +258,35 @@ class Matmul_AB_2p5D(torch.autograd.Function): ...@@ -202,7 +258,35 @@ class Matmul_AB_2p5D(torch.autograd.Function):
class Matmul_ABT_2p5D(torch.autograd.Function): class Matmul_ABT_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = AB^T` """
Matrix multiplication for :math:`C = AB^T`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param dep_rank: the rank of depth
:type dep_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -308,7 +392,35 @@ class Matmul_ABT_2p5D(torch.autograd.Function): ...@@ -308,7 +392,35 @@ class Matmul_ABT_2p5D(torch.autograd.Function):
class Matmul_ATB_2p5D(torch.autograd.Function): class Matmul_ATB_2p5D(torch.autograd.Function):
"""Matrix multiplication for :math:`C = A^TB` """
Matrix multiplication for :math:`C = A^TB`
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param dep_rank: the rank of depth
:type dep_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -411,7 +523,35 @@ class Matmul_ATB_2p5D(torch.autograd.Function): ...@@ -411,7 +523,35 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
class Add_Bias_2p5D(torch.autograd.Function): class Add_Bias_2p5D(torch.autograd.Function):
"""Matrix add bias: :math:`C = A + b` """
Matrix add bias: :math:`C = A + b`
:param input: matrix :math:`A`
:type input: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: output size in each partition
:type output_size_per_partition: int
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
...@@ -482,6 +622,20 @@ class Add_Bias_2p5D(torch.autograd.Function): ...@@ -482,6 +622,20 @@ class Add_Bias_2p5D(torch.autograd.Function):
class layernorm_2p5d(torch.autograd.Function): class layernorm_2p5d(torch.autograd.Function):
"""
Layernorm
:param input: input maxtrix
:type input: torch.tensor
:param E_x: mean
:type E_x: torch.tensor
:param Var_x: variance
:type Var_x: torch.tensor
:param hidden_size: hidden size
:type hidden_size: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, def forward(ctx: Any, input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
...@@ -518,6 +672,18 @@ class layernorm_2p5d(torch.autograd.Function): ...@@ -518,6 +672,18 @@ class layernorm_2p5d(torch.autograd.Function):
class all_gather_weight_2p5d(torch.autograd.Function): class all_gather_weight_2p5d(torch.autograd.Function):
"""
all gather the weight of 2.5D parallelism
:param inputs: input maxtrix
:type inputs: torch.tensor
:param dim: dimension of all gather
:type dim: int
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
...@@ -536,6 +702,14 @@ class all_gather_weight_2p5d(torch.autograd.Function): ...@@ -536,6 +702,14 @@ class all_gather_weight_2p5d(torch.autograd.Function):
class SplitFirst(torch.autograd.Function): class SplitFirst(torch.autograd.Function):
"""
:param inputs: input maxtrix
:type inputs: torch.tensor
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
...@@ -566,7 +740,14 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -566,7 +740,14 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
class reduce_by_batch_2p5d(torch.autograd.Function): class reduce_by_batch_2p5d(torch.autograd.Function):
"""All-reduce the input from the model parallel region.""" """
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
@staticmethod @staticmethod
def symbolic(graph, input_, reduce_mean: bool = False): def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
......
...@@ -21,7 +21,8 @@ from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from ...@@ -21,7 +21,8 @@ from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from
@LAYERS.register_module @LAYERS.register_module
class Linear2p5D(ParallelLayer): class Linear2p5D(ParallelLayer):
"""Linear layer for 2.5D parallelism """
Linear layer for 2.5D parallelism
:param in_features: size of each input sample :param in_features: size of each input sample
:type in_features: int :type in_features: int
...@@ -31,6 +32,10 @@ class Linear2p5D(ParallelLayer): ...@@ -31,6 +32,10 @@ class Linear2p5D(ParallelLayer):
:type bias: bool, optional :type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
...@@ -125,7 +130,8 @@ class Linear2p5D(ParallelLayer): ...@@ -125,7 +130,8 @@ class Linear2p5D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class LayerNorm2p5D(ParallelLayer): class LayerNorm2p5D(ParallelLayer):
r"""Layer Normalization for 2.5D parallelism r"""
Layer Normalization for 2.5D parallelism
:param normalized_shape: input shape from an expected input :param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]` of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
...@@ -196,17 +202,27 @@ class LayerNorm2p5D(ParallelLayer): ...@@ -196,17 +202,27 @@ class LayerNorm2p5D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class PatchEmbedding2p5D(ParallelLayer): class PatchEmbedding2p5D(ParallelLayer):
""" 2D Image to Patch Embedding """
:param img_size: iamge size 2D Image to Patch Embedding
:param img_size: image size
:type img_size: int :type img_size: int
:param patch_size: patch size :param patch_size: patch size
:type patch_size: int :type patch_size: int
:param embed_dim: dimension of embedding :param in_chans: number of channels of input image
:type embed_dim: int :type in_chans: int
:param in_chans: number of channels of input image, defaults to 3 :param embed_size: size of embedding
:type in_chans: int, optional :type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True :param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional :type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
...@@ -286,6 +302,20 @@ class PatchEmbedding2p5D(ParallelLayer): ...@@ -286,6 +302,20 @@ class PatchEmbedding2p5D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Embedding2p5D(ParallelLayer): class Embedding2p5D(ParallelLayer):
"""
Embedding for 2.5D parallelism
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
...@@ -336,6 +366,24 @@ class Embedding2p5D(ParallelLayer): ...@@ -336,6 +366,24 @@ class Embedding2p5D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Classifier2p5D(ParallelLayer): class Classifier2p5D(ParallelLayer):
"""
Classifier for 2.5D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
......
...@@ -12,6 +12,28 @@ from torch.cuda.amp import custom_bwd, custom_fwd ...@@ -12,6 +12,28 @@ from torch.cuda.amp import custom_bwd, custom_fwd
class linear_3d(torch.autograd.Function): class linear_3d(torch.autograd.Function):
"""
Linear layer for 3D parallelism
:param input_: matrix of input
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, def forward(ctx,
...@@ -74,6 +96,22 @@ class linear_3d(torch.autograd.Function): ...@@ -74,6 +96,22 @@ class linear_3d(torch.autograd.Function):
class classifier_3d(torch.autograd.Function): class classifier_3d(torch.autograd.Function):
"""
Classifier
:param input_: matrix of input
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
...@@ -129,6 +167,29 @@ class classifier_3d(torch.autograd.Function): ...@@ -129,6 +167,29 @@ class classifier_3d(torch.autograd.Function):
class layernorm_3d(torch.autograd.Function): class layernorm_3d(torch.autograd.Function):
"""
Layernorm
:param input_: input maxtrix
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float, def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
...@@ -189,6 +250,18 @@ def split_tensor_3d(input_: Tensor, ...@@ -189,6 +250,18 @@ def split_tensor_3d(input_: Tensor,
class reduce_by_batch_3d(torch.autograd.Function): class reduce_by_batch_3d(torch.autograd.Function):
"""
All-reduce the input from the model parallel region.
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:type reduce_mean: int, optional
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, def forward(ctx,
...@@ -215,6 +288,18 @@ class reduce_by_batch_3d(torch.autograd.Function): ...@@ -215,6 +288,18 @@ class reduce_by_batch_3d(torch.autograd.Function):
class broadcast_weight_3d_from_diagonal(torch.autograd.Function): class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
"""
broadcast weight from diagonal
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: output parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
......
...@@ -24,6 +24,19 @@ from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_e ...@@ -24,6 +24,19 @@ from ._utils import get_depth_from_env, get_last_group, get_parallel_mode_from_e
@LAYERS.register_module @LAYERS.register_module
class LayerNorm3D(ParallelLayer): class LayerNorm3D(ParallelLayer):
r"""
Layer Normalization for 3D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-12
:type eps: float, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None): def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
...@@ -55,6 +68,22 @@ class LayerNorm3D(ParallelLayer): ...@@ -55,6 +68,22 @@ class LayerNorm3D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Linear3D(ParallelLayer): class Linear3D(ParallelLayer):
"""
Linear layer for 3D parallelism
:param in_features: size of each input sample
:type in_features: int
:param out_features: size of each output sample
:type out_features: int
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
...@@ -113,6 +142,24 @@ class Linear3D(ParallelLayer): ...@@ -113,6 +142,24 @@ class Linear3D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Classifier3D(ParallelLayer): class Classifier3D(ParallelLayer):
"""
Classifier for 3D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
...@@ -173,6 +220,28 @@ class Classifier3D(ParallelLayer): ...@@ -173,6 +220,28 @@ class Classifier3D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class PatchEmbedding3D(ParallelLayer): class PatchEmbedding3D(ParallelLayer):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
...@@ -256,6 +325,20 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -256,6 +325,20 @@ class PatchEmbedding3D(ParallelLayer):
@LAYERS.register_module @LAYERS.register_module
class Embedding3D(ParallelLayer): class Embedding3D(ParallelLayer):
"""
Embedding for 3D parallelism
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
......
...@@ -32,7 +32,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False): ...@@ -32,7 +32,8 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
class DropPath(nn.Module): class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
""" """
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
...@@ -97,7 +98,27 @@ class WrappedDropPath(nn.Module): ...@@ -97,7 +98,27 @@ class WrappedDropPath(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class VanillaPatchEmbedding(nn.Module): class VanillaPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding """
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
...@@ -148,6 +169,24 @@ class VanillaPatchEmbedding(nn.Module): ...@@ -148,6 +169,24 @@ class VanillaPatchEmbedding(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class VanillaClassifier(nn.Module): class VanillaClassifier(nn.Module):
"""
Classifier
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
......
...@@ -7,7 +7,8 @@ from torch.nn.modules.loss import _Loss ...@@ -7,7 +7,8 @@ from torch.nn.modules.loss import _Loss
@LOSSES.register_module @LOSSES.register_module
class CrossEntropyLoss2D(_Loss): class CrossEntropyLoss2D(_Loss):
"""Cross entropy loss for 2D parallelism """
Cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True :param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional :type reduction: bool, optional
......
...@@ -7,7 +7,9 @@ from torch.nn.modules.loss import _Loss ...@@ -7,7 +7,9 @@ from torch.nn.modules.loss import _Loss
@LOSSES.register_module @LOSSES.register_module
class CrossEntropyLoss2p5D(_Loss): class CrossEntropyLoss2p5D(_Loss):
"""Cross entropy loss for 2.5D parallelism """
Cross entropy loss for 2.5D parallelism
:param reduction: whether to average the loss, defaults to True :param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional :type reduction: bool, optional
""" """
......
...@@ -7,14 +7,11 @@ from torch.nn.modules.loss import _Loss ...@@ -7,14 +7,11 @@ from torch.nn.modules.loss import _Loss
@LOSSES.register_module @LOSSES.register_module
class CrossEntropyLoss3D(_Loss): class CrossEntropyLoss3D(_Loss):
"""Cross entropy loss for 3D parallelism """
Cross entropy loss for 3D parallelism
:param depth: depth for 3D parallelism :param depth: depth for 3D parallelism
:type depth: int :type depth: int
:param input_parallel_mode: parallel mode for input tensor
:type input_parallel_mode: ParallelMode
:param weight_parallel_mode: parallel mode for weight
:type weight_parallel_mode: ParallelMode
:param reduction: whether to average the loss, defaults to True :param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional :type reduction: bool, optional
""" """
......
...@@ -6,6 +6,12 @@ from ._utils import calc_acc ...@@ -6,6 +6,12 @@ from ._utils import calc_acc
class Accuracy2D(nn.Module): class Accuracy2D(nn.Module):
"""
Accuracy for 2D parallelism
:param logits: predicted labels
:param targets: true labels
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -6,6 +6,12 @@ from ._utils import calc_acc ...@@ -6,6 +6,12 @@ from ._utils import calc_acc
class Accuracy2p5D(nn.Module): class Accuracy2p5D(nn.Module):
"""
Accuracy for 2p5D parallelism
:param logits: predicted labels
:param targets: true labels
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -8,6 +8,12 @@ from ._utils import calc_acc ...@@ -8,6 +8,12 @@ from ._utils import calc_acc
class Accuracy3D(nn.Module): class Accuracy3D(nn.Module):
"""
Accuracy for 3D parallelism
:param logits: predicted labels
:param targets: true labels
"""
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment