Commit 9ee197d0 authored by アマデウス, committed by Frank Lee

moved env variables to global variables; (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communicator bugs
parent b82d60be
-from .layers import Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row
-from .layers import MixedFusedLayerNorm1D as LayerNorm1D
from .layers import (Classifier1D, Dropout1D, Embedding1D, Linear1D, Linear1D_Col, Linear1D_Row,
                     VocabParallelClassifier1D, VocabParallelEmbedding1D)

-__all__ = ['Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'LayerNorm1D', 'Embedding1D', 'Dropout1D']
__all__ = [
    'Linear1D', 'Linear1D_Col', 'Linear1D_Row', 'Embedding1D', 'Dropout1D', 'Classifier1D', 'VocabParallelClassifier1D',
    'VocabParallelEmbedding1D'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

-import os
import torch
import torch.distributed as dist
-from colossalai.constants import PARALLEL_INPUT_1D
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from ..utils import divide


def set_parallel_input(input_parallel: bool):
-    os.environ[PARALLEL_INPUT_1D] = 'true' if input_parallel else ''
    env.parallel_input_1d = input_parallel


def get_parallel_input():
-    return bool(os.environ[PARALLEL_INPUT_1D])
    return env.parallel_input_1d


def vocab_range_from_per_partition_vocab_size(per_partition_vocab_size, rank):
......
@@ -2,8 +2,6 @@
# -*- encoding: utf-8 -*-

import math
-import numbers
-from contextlib import nullcontext
from typing import Callable, Tuple

import torch
@@ -11,17 +9,17 @@ import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.registry import LAYERS
-from colossalai.utils import get_current_device
from colossalai.utils.cuda import get_current_device
-from torch import Tensor, dtype
from torch import Tensor
from torch.nn.parameter import Parameter

from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition
-from ._operation import FusedLayerNormAffineFunction1D
-from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad, reduce_input, set_parallel_input,
-                     split_forward_gather_backward)
from ._utils import (gather_forward_split_backward, get_parallel_input, reduce_grad,
                     reduce_input, set_parallel_input, split_forward_gather_backward)
@LAYERS.register_module
@@ -44,6 +42,7 @@ class Linear1D(torch.nn.Module):
    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
    :type bias_initializer: typing.Callable, optional
    """

    def __init__(self,
                 in_features: int,
                 out_features: int,
@@ -106,12 +105,13 @@ class Classifier1D(ParallelLayer):
    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
    :type bias_initializer: typing.Callable, optional
    """

    def __init__(self,
                 in_features: int,
                 num_classes: int,
                 weight: Parameter = None,
                 bias: bool = True,
-                 dtype: dtype = None,
                 dtype: torch.dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()
@@ -139,6 +139,7 @@ class Classifier1D(ParallelLayer):
        self.reset_parameters(weight_initializer, bias_initializer)
        self._set_tensor_parallel_attributes()
        set_parallel_input(False)
        env.vocab_parallel = False

    def reset_parameters(self, weight_initializer, bias_initializer) -> None:
        fan_in, fan_out = self.in_features, self.num_classes
@@ -167,6 +168,84 @@ class Classifier1D(ParallelLayer):
        return output
@LAYERS.register_module
class VocabParallelClassifier1D(ParallelLayer):
"""ColLinear with given weight
Classifier of 1D parallelism
:param in_features: size of input features
:type in_features: int
:param num_classes: number of classes in the dataset
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.parallel_input = get_parallel_input()
# Divide the weight matrix along the last dimension.
self.num_classes_per_partition = divide(num_classes, gpc.tensor_parallel_size)
# Parameters.
# Initialize weight.
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(torch.empty(self.num_classes_per_partition, self.in_features, **factory_kwargs))
self.has_weight = True
if bias:
self.bias = Parameter(torch.empty(self.num_classes_per_partition, **factory_kwargs))
else:
self.bias = None
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def _set_tensor_parallel_attributes(self):
num_partition = gpc.get_world_size(ParallelMode.TENSOR)
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, num_partition)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, num_partition)
def forward(self, input_: Tensor) -> Tensor:
# Set up backprop all-reduce.
input_parallel = reduce_grad(input_, ParallelMode.PARALLEL_1D)
# Matrix multiply.
output = F.linear(input_parallel, self.weight, self.bias)
return output
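A minimal usage sketch (illustrative, not part of this commit), assuming colossalai.launch has already set up a 1D tensor parallel group of size 4 and that both layers are exported from colossalai.nn.layer.parallel_1d as in the __init__.py diff above:

# Hypothetical example: tie a vocab-parallel embedding and classifier, as a GPT-style model would.
import torch
from colossalai.nn.layer.parallel_1d import VocabParallelEmbedding1D, VocabParallelClassifier1D

vocab_size, hidden_size = 50304, 1024
embed = VocabParallelEmbedding1D(vocab_size, hidden_size)    # local weight: [vocab_size / 4, hidden_size]
head = VocabParallelClassifier1D(hidden_size, vocab_size, weight=embed.weight, bias=False)  # weight tying works because the shards match
tokens = torch.randint(0, vocab_size, (8, 512), device='cuda')
hidden = embed(tokens)    # all-reduced over the 1D group -> full [8, 512, hidden_size]
logits = head(hidden)     # stays vocab-partitioned: [8, 512, vocab_size / 4] on each rank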
@LAYERS.register_module
class Linear1D_Col(ParallelLayer):
    """Linear layer with column parallelism.
@@ -324,7 +403,7 @@ class Linear1D_Row(ParallelLayer):
            weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
            if self.bias is not None:
                bias_initializer(self.bias, fan_in=fan_in)
                broadcast(self.bias, gpc.get_ranks_in_group(ParallelMode.PARALLEL_1D)[0], ParallelMode.PARALLEL_1D)

    def _set_tensor_parallel_attributes(self):
        num_partition = gpc.get_world_size(ParallelMode.TENSOR)
@@ -341,45 +420,13 @@ class Linear1D_Row(ParallelLayer):
        output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

        if not self.skip_bias_add:
-            output = output + self.bias
            if self.bias is not None:
                output = output + self.bias
            return output
        else:
            return output, self.bias
@LAYERS.register_module
class MixedFusedLayerNorm1D(torch.nn.Module):
r"""
Layer Normalization for 1D parallelism
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability, defaults to 1e-05
:type eps: float, optional
"""
def __init__(self, normalized_shape, eps=1e-5):
super(MixedFusedLayerNorm1D, self).__init__()
if isinstance(normalized_shape, numbers.Integral):
normalized_shape = (normalized_shape, )
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.reset_parameters()
def reset_parameters(self):
init.ones_(self.weight)
init.zeros_(self.bias)
def forward(self, input):
return FusedLayerNormAffineFunction1D.apply(input, self.weight, self.bias, self.normalized_shape, self.eps)
@LAYERS.register_module
class Embedding1D(ParallelLayer):
    """
@@ -398,11 +445,12 @@ class Embedding1D(ParallelLayer):
    :param args: Args used in F.embedding
    :param kwargs: Kwargs used in F.embedding
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int = None,
-                 dtype: dtype = None,
                 dtype: torch.dtype = None,
                 weight_initializer: Callable = init.normal_(),
                 *args,
                 **kwargs):
@@ -446,6 +494,84 @@ class Embedding1D(ParallelLayer):
        return output
@LAYERS.register_module
class VocabParallelEmbedding1D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
tensor_parallel_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
set_parallel_input(False)
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, gpc.tensor_parallel_size)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce across all the model parallel GPUs.
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
return output
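The forward pass above relies on masking token ids that fall outside this rank's vocabulary shard; a single-process sketch of that idea (illustrative values, not from this commit):

import torch
import torch.nn.functional as F

vocab_start, vocab_end = 100, 200                       # this rank owns ids [100, 200)
weight_shard = torch.randn(100, 8)                      # [vocab shard size, embed dim]
input_ = torch.tensor([[5, 150, 199, 250]])
mask = (input_ < vocab_start) | (input_ >= vocab_end)   # ids owned by other ranks
local_ids = (input_ - vocab_start).masked_fill(mask, 0) # clamp to safe local indices
partial = F.embedding(local_ids, weight_shard)
partial[mask, :] = 0.                                   # zero the rows this rank does not own
# an all-reduce over the tensor parallel group then sums the partial embeddings into the full result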
@LAYERS.register_module
class Dropout1D(ParallelLayer):
    """
@@ -456,6 +582,7 @@ class Dropout1D(ParallelLayer):
    :param inplace: If set to ``True``, will do this operation in-place, defaults to ``False``
    :type inplace: bool, optional
    """

    def __init__(self, p: float = 0.5, inplace: bool = False):
        super().__init__()
        self.parallel_input = get_parallel_input()
@@ -463,7 +590,9 @@ class Dropout1D(ParallelLayer):
        self.inplace = inplace

    def forward(self, input_: Tensor) -> Tensor:
-        cm = nullcontext() if not self.parallel_input else seed(ParallelMode.TENSOR)
-        with cm:
-            output = F.dropout(input_, self.p, self.training, self.inplace)
        if self.parallel_input:
            with seed(ParallelMode.TENSOR):
                output = F.dropout(input_, self.p, self.training, self.inplace)
        else:
            output = F.dropout(input_, self.p, self.training, self.inplace)
        return output
from ._operation import reduce_by_batch_2d, split_tensor_2d
-from .layers import Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D
from .layers import (Classifier2D, Embedding2D, LayerNorm2D, Linear2D, PatchEmbedding2D, VocabParallelClassifier2D,
                     VocabParallelEmbedding2D)

__all__ = [
-    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D', 'Embedding2D'
    'split_tensor_2d', 'reduce_by_batch_2d', 'Linear2D', 'LayerNorm2D', 'Classifier2D', 'PatchEmbedding2D',
    'Embedding2D', 'VocabParallelEmbedding2D', 'VocabParallelClassifier2D'
]
@@ -8,6 +8,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd
from colossalai.global_variables import tensor_parallel_env as env


def matmul_2d(
@@ -22,6 +23,7 @@ def matmul_2d(
):
    """
    Matrix multiplication for 2D parallelism

    :param a: matrix :math:`A`
    :type a: torch.tensor
    :param b: matrix :math:`B`
@@ -56,37 +58,7 @@ def matmul_2d(
                     data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
-class classifier_2d(torch.autograd.Function):
class _Classifier2D(torch.autograd.Function):
"""
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param summa_dim: dimension of SUMMA for 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -150,14 +122,54 @@ class classifier_2d(torch.autograd.Function):
        B_grad = torch.matmul(output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), A)
        B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
        B_grad = B_grad.reshape(ctx.B_shape)

        bias_grad = None
        if ctx.use_bias:
            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
            bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
-        else:
-            bias_grad = None

        return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
def classifier_2d(A: Tensor, B: Tensor, bias: Optional[Tensor], summa_dim: int, out_shape: Tuple[int, ...],
row_rank: int, col_rank: int, row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
"""
2D parallel classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param summa_dim: dimension of SUMMA for 2D parallelism
:type summa_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Classifier2D.apply(A, B, bias, summa_dim, out_shape, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
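A hedged sketch of a call site (all names are placeholders; the real usage is Classifier2D.forward further below). The op multiplies the local activation shard A by the transposed weight shard B and adds the bias cooperatively over the 2D process grid:

# Placeholder names, assuming an initialized 2D (SUMMA) tensor parallel environment.
logits = classifier_2d(A, B, bias, summa_dim, A.shape[:-1] + (num_classes, ),
                       row_rank, col_rank,
                       ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
                       data_parallel_rank, pipeline_parallel_rank,
                       pipeline_parallel_size, tensor_parallel_size)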
class Matmul_AB_2D(torch.autograd.Function):
    """
    Matrix multiplication for :math:`C = AB`
@@ -230,9 +242,9 @@ class Matmul_AB_2D(torch.autograd.Function):
        col_group = gpc.get_group(col_parallel_mode)

        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size

        opa = [None] * 2
        opb = [None] * 2
@@ -361,9 +373,9 @@ class Matmul_ABT_2D(torch.autograd.Function):
        col_group = gpc.get_group(col_parallel_mode)

        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size

        opb = [None] * 2
        opr = [None] * 2
@@ -501,9 +513,9 @@ class Matmul_ATB_2D(torch.autograd.Function):
        col_group = gpc.get_group(col_parallel_mode)

        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size
        src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
            pipeline_parallel_rank * tensor_parallel_size

        opa = [None] * 2
        opr = [None] * 2
@@ -572,35 +584,7 @@ class Matmul_ATB_2D(torch.autograd.Function):
        return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None
-class add_bias_2d(torch.autograd.Function):
class _Add_Bias_2D(torch.autograd.Function):
"""
Matrix add bias: :math:`C = A + b`
:param input_: matrix :math:`A`
:type input_: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: size of output per partition
:type output_size_per_partition: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -651,31 +635,47 @@ class add_bias_2d(torch.autograd.Function):
        return output_grad, grad, None, None, None, None, None, None, None, None, None, None
-class layernorm_2d(torch.autograd.Function):
-    """
-    Layernorm
-    :param input_: input maxtrix
-    :type input_: torch.tensor
-    :param E_x: mean
-    :type E_x: torch.tensor
-    :param Var_x: variance
-    :type Var_x: torch.tensor
-    :param hidden_size: hidden size
-    :type hidden_size: int
-    :param row_parallel_mode: row parallel mode
-    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    :param col_parallel_mode: column parallel mode
-    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    """
def add_bias_2d(input_: Tensor, bias: Tensor, output_size_per_partition: int, row_rank: int, col_rank: int,
                row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, skip_bias_add: bool,
                data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
                tensor_parallel_size: int) -> Tensor:
    """
    Matrix add bias: :math:`C = A + b`
    :param input_: matrix :math:`A`
    :type input_: torch.tensor
    :param bias: matrix :math:`b`
    :type bias: torch.tensor
    :param output_size_per_partition: size of output per partition
    :type output_size_per_partition: int
    :param row_rank: the rank of row
    :type row_rank: int
    :param col_rank: the rank of column
    :type col_rank: int
    :param row_parallel_mode: row parallel mode
    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    :param col_parallel_mode: column parallel mode
    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    :param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
    :type skip_bias_add: bool
    :param data_parallel_rank: data parallel rank
    :type data_parallel_rank: int
    :param pipeline_parallel_rank: pipeline parallel rank
    :type pipeline_parallel_rank: int
    :param pipeline_parallel_size: pipeline parallel size
    :type pipeline_parallel_size: int
    :param tensor_parallel_size: tensor parallel size
    :type tensor_parallel_size: int
    """
    return _Add_Bias_2D.apply(input_, bias, output_size_per_partition, row_rank, col_rank, row_parallel_mode,
                              col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank,
                              pipeline_parallel_size, tensor_parallel_size)
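The same refactoring pattern is applied to every op in this commit: the torch.autograd.Function becomes a private class and a thin, documented function wraps its .apply(). A minimal self-contained sketch of the pattern (hypothetical names, not from this commit):

import torch

class _Scale2x(torch.autograd.Function):
    # private implementation detail
    @staticmethod
    def forward(ctx, x):
        return x * 2

    @staticmethod
    def backward(ctx, grad):
        return grad * 2

def scale_2x(x: torch.Tensor) -> torch.Tensor:
    """Public, documented entry point; callers no longer touch .apply() directly."""
    return _Scale2x.apply(x)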
class _Layernorm_2D(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
-    def forward(ctx: Any,
-                input_: Tensor,
-                E_x: Tensor,
-                Var_x: Tensor,
-                hidden_size: int,
-                row_parallel_mode: ParallelMode,
-                col_parallel_mode: ParallelMode) -> Tensor:
    def forward(ctx: Any, input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
                col_parallel_mode: ParallelMode) -> Tensor:
        input_ = input_ - E_x
        # in here, input = x - E[x], Var_x = 1 / sqrt(Var[x] + eps)
@@ -709,76 +709,64 @@ class layernorm_2d(torch.autograd.Function):
        return input_grad, None, None, None, None, None
-class all_gather_weight_2d(torch.autograd.Function):
-    """
-    all gather the weight of 2D parallelism
-    :param inputs: input maxtrix
-    :type inputs: torch.tensor
-    :param dim: dimension of all gather
-    :type dim: int
-    :param summa_dim: dimension of SUMMA fo 2D parallelism
-    :type summa_dim: int
-    :param col_parallel_mode: column parallel mode
-    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    """
def layernorm_2d(input_: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int, row_parallel_mode: ParallelMode,
                 col_parallel_mode: ParallelMode) -> Tensor:
    """
    Layernorm
    :param input_: input matrix
    :type input_: torch.tensor
    :param E_x: mean
    :type E_x: torch.tensor
    :param Var_x: variance
    :type Var_x: torch.tensor
    :param hidden_size: hidden size
    :type hidden_size: int
    :param row_parallel_mode: row parallel mode
    :type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    :param col_parallel_mode: column parallel mode
    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
    """
    return _Layernorm_2D.apply(input_, E_x, Var_x, hidden_size, row_parallel_mode, col_parallel_mode)
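For reference, a single-device sketch of the statistics this op consumes (E_x and Var_x are produced by LayerNorm2D further below; values here are illustrative, not from this commit):

import torch

x = torch.randn(4, 16)                                        # [tokens, hidden]
E_x = x.mean(dim=-1, keepdim=True)                            # mean over the full hidden dimension
Var_x = 1.0 / torch.sqrt(x.var(dim=-1, unbiased=False, keepdim=True) + 1e-5)
y = (x - E_x) * Var_x                                         # core of layernorm_2d: (x - E[x]) / sqrt(Var[x] + eps)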
class _AllGatherTensor2D(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx: Any, inputs: Tensor, dim: int, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
    def forward(ctx: Any, inputs: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
        ctx.dim = dim
-        ctx.summa_dim = summa_dim
-        ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
-        outputs = all_gather(inputs, dim, col_parallel_mode)
        ctx.parallel_mode = parallel_mode
        outputs = all_gather(inputs, dim, parallel_mode)
        return outputs

    @staticmethod
    @custom_bwd
    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        grad = output_grad.chunk(ctx.summa_dim, dim=ctx.dim)[ctx.row_rank]
-        return grad.contiguous(), None, None, None
        grad = reduce_scatter(output_grad, ctx.dim, ctx.parallel_mode)
        return grad.contiguous(), None, None
-class SplitFirst(torch.autograd.Function):
-    """
-    :param inputs: input maxtrix
-    :type inputs: torch.tensor
-    :param summa_dim: dimension of SUMMA fo 2D parallelism
-    :type summa_dim: int
-    :param col_parallel_mode: column parallel mode
-    :type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
-    """
-    @staticmethod
-    @custom_fwd(cast_inputs=torch.float16)
-    def forward(ctx: Any, inputs: Tensor, summa_dim: int, col_parallel_mode: ParallelMode) -> Tensor:
-        ctx.summa_dim = summa_dim
-        ctx.batch_size = inputs.size(0)
-        ctx.para_mode = col_parallel_mode
-        row_rank = gpc.get_local_rank(col_parallel_mode)
-        outputs = inputs.chunk(summa_dim, dim=0)[row_rank]
-        return outputs
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
-        grad_shape = (ctx.batch_size, ) + output_grad.shape[1:]
-        grad = torch.empty(grad_shape, dtype=output_grad.dtype, device=get_current_device())
-        dist.all_gather(list(grad.chunk(ctx.summa_dim, dim=0)),
-                        output_grad.contiguous(),
-                        group=gpc.get_group(ctx.para_mode))
-        return grad, None, None
def all_gather_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
    """
    All gather the tensor of 2D parallelism
    :param inputs: input matrix
    :type inputs: torch.tensor
    :param dim: dimension to gather
    :type dim: int
    :param parallel_mode: parallel mode
    :type parallel_mode: colossalai.context.parallel_mode.ParallelMode
    """
    return _AllGatherTensor2D.apply(tensor, dim, parallel_mode)
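The new op's autograd behaviour in brief (sketch with placeholder variables, assuming an initialized 2D column group):

# forward:  shard -> full tensor via all-gather along `dim`
# backward: full gradient -> summed shard via reduce-scatter along `dim`
w_full = all_gather_tensor_2d(w_shard, 0, ParallelMode.PARALLEL_2D_COL)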
def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
    """Splits 2D tensor in specified dimension across cols
    :param input_: Input tensor
    :param dim: Specified dimension in which to split
    :type input_: torch.Tensor
    :type dim: int, optional
    :return output: Split tensor
    :rtype output: torch.Tensor
    """
@@ -788,9 +776,50 @@ def split_tensor_2d(input_: Tensor, dim: int = 0) -> Tensor:
                      dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)].contiguous()
-class reduce_by_batch_2d(torch.autograd.Function):
-    """All-reduce the input from the model parallel region.
class _ReduceTensor2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_2d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
""" """
All-reduce the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
return _ReduceTensor2D.apply(input_, parallel_mode)
class _ReduceScatterTensor2D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None
def reduce_scatter_tensor_2d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
"""
return _ReduceScatterTensor2D.apply(tensor, dim, parallel_mode)
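Where these helpers are used (hedged sketch, placeholder variables): partial results produced on each rank are summed over the 2D column group, either kept whole or scattered along a dimension, as VocabParallelEmbedding2D does further below.

summed = reduce_tensor_2d(partial, ParallelMode.PARALLEL_2D_COL)               # every rank gets the full sum
sharded = reduce_scatter_tensor_2d(partial, 0, ParallelMode.PARALLEL_2D_COL)   # each rank keeps one slice of the sum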
class _ReduceByBatch2D(torch.autograd.Function):
    @staticmethod
    def symbolic(graph, input_, reduce_mean: bool = False):
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
@@ -802,12 +831,6 @@ class reduce_by_batch_2d(torch.autograd.Function):
    @staticmethod
    @custom_fwd(cast_inputs=torch.float32)
    def forward(ctx, input_, reduce_mean: bool = False):
-        """
-        :param input_: input maxtrix
-        :type input_: torch.tensor
-        :param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
-        :type reduce_mean: int, optional
-        """
        output = all_reduce(input_, ParallelMode.PARALLEL_2D_COL)
        ctx.reduce_mean = reduce_mean
        if reduce_mean:
@@ -823,3 +846,14 @@ class reduce_by_batch_2d(torch.autograd.Function):
            return output_grad / ctx.reduce_size, None
        else:
            return output_grad, None
def reduce_by_batch_2d(input_, reduce_mean: bool = False) -> Tensor:
"""All-reduce the input from the model parallel region.
:param input_: input matrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, defaults to False
:type reduce_mean: bool, optional
"""
return _ReduceByBatch2D.apply(input_, reduce_mean)
\ No newline at end of file
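Typical use of reduce_by_batch_2d (illustrative placeholder `correct`, assuming an initialized 2D group): averaging a per-rank scalar such as a correct-prediction count over the column group so every rank sees the same metric.

acc = reduce_by_batch_2d(correct.float(), reduce_mean=True)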
-import os
from colossalai.context.parallel_mode import ParallelMode
-from colossalai.context.process_group_initializer.initializer_2d import SUMMA_DIM
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env


def get_summa_dim_from_env() -> int:
    try:
-        summa_dim = os.environ[SUMMA_DIM]
-        summa_dim = int(summa_dim)
        summa_dim = env.summa_dim
        assert summa_dim > 0, 'SUMMA_DIM must be larger than zero'
        return summa_dim
......
@@ -7,15 +7,16 @@ import torch.nn.functional as F
from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init
from colossalai.registry import LAYERS
-from colossalai.utils import get_current_device
from colossalai.utils.cuda import get_current_device
-from torch import Tensor, dtype
from torch import Tensor
from torch.nn import Parameter

-from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ..base_layer import ParallelLayer
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import Matmul_AB_2D, add_bias_2d, all_gather_weight_2d, classifier_2d, layernorm_2d
from ._operation import *
from ._utils import assert_summa_initialization, get_summa_dim_from_env
@@ -43,7 +44,7 @@ class Linear2D(ParallelLayer):
                 in_features: int,
                 out_features: int,
                 bias: bool = True,
-                 dtype=None,
                 dtype: torch.dtype = None,
                 skip_bias_add: bool = False,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
@@ -101,16 +102,16 @@ class Linear2D(ParallelLayer):
        if self.bias is not None:
            if self.skip_bias_add:
-                bias = add_bias_2d.apply(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
-                                         ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
-                                         self.data_parallel_rank, self.pipeline_parallel_rank,
-                                         self.pipeline_parallel_size, self.tensor_parallel_size)
                bias = add_bias_2d(None, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
                                   ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
                                   self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                                   self.tensor_parallel_size)
                return output, bias
            else:
-                output = add_bias_2d.apply(output, self.bias, self.hidden_size_per_partition, self.row_rank,
-                                           self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
-                                           False, self.data_parallel_rank, self.pipeline_parallel_rank,
-                                           self.pipeline_parallel_size, self.tensor_parallel_size)
                output = add_bias_2d(output, self.bias, self.hidden_size_per_partition, self.row_rank, self.col_rank,
                                     ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
                                     self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
                                     self.tensor_parallel_size)
                return output
        else:
            return output
@@ -174,16 +175,14 @@ class LayerNorm2D(ParallelLayer):
        # this time 1/sqrt(Var_x + epsilon)
        Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)

-        output = layernorm_2d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
-                                    ParallelMode.PARALLEL_2D_COL)
-        bias = add_bias_2d.apply(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
-                                 ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
-                                 self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                 self.tensor_parallel_size)
-        scale = add_bias_2d.apply(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
-                                  ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True,
-                                  self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                  self.tensor_parallel_size)
        output = layernorm_2d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2D_ROW,
                              ParallelMode.PARALLEL_2D_COL)
        bias = add_bias_2d(None, self.beta, self.partitioned_partition, self.row_rank, self.col_rank,
                           ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                           self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
        scale = add_bias_2d(None, self.gamma, self.partitioned_partition, self.row_rank, self.col_rank,
                            ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, True, self.data_parallel_rank,
                            self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
        output = torch.addcmul(bias, scale, output)
        return output
@@ -217,8 +216,8 @@ class PatchEmbedding2D(ParallelLayer):
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
-                 dtype: dtype = None,
                 flatten: bool = True,
                 dtype: torch.dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 position_embed_initializer: Callable = init.zeros_()):
@@ -268,19 +267,21 @@ class PatchEmbedding2D(ParallelLayer):
        position_embed_initializer(self.pos_embed)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_2d(input_)

        B, C, H, W = input_.shape
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."

-        weight = all_gather_weight_2d.apply(self.weight, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
-        bias = all_gather_weight_2d.apply(self.bias, 0, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
        weight = all_gather_tensor_2d(self.weight, 0, ParallelMode.PARALLEL_2D_COL)
        bias = all_gather_tensor_2d(self.bias, 0, ParallelMode.PARALLEL_2D_COL)

        output = F.conv2d(input_, weight, bias, stride=self.patch_size)
        if self.flatten:
            output = output.flatten(2).transpose(1, 2)  # BCHW -> BNC

-        cls_token = all_gather_weight_2d.apply(self.cls_token, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
-        pos_embed = all_gather_weight_2d.apply(self.pos_embed, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
        cls_token = all_gather_tensor_2d(self.cls_token, -1, ParallelMode.PARALLEL_2D_COL)
        pos_embed = all_gather_tensor_2d(self.pos_embed, -1, ParallelMode.PARALLEL_2D_COL)
        cls_token = cls_token.expand(output.shape[0], -1, -1)
        output = torch.cat((cls_token, output), dim=1)
        output = output + pos_embed
@@ -310,7 +311,7 @@ class Embedding2D(ParallelLayer):
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: int = None,
-                 dtype: dtype = None,
                 dtype: torch.dtype = None,
                 weight_initializer: Callable = init.normal_(),
                 *args,
                 **kwargs):
@@ -347,13 +348,90 @@ class Embedding2D(ParallelLayer):
                self.weight[self.padding_idx].fill_(0)

    def forward(self, input_: Tensor) -> Tensor:
        input_ = split_tensor_2d(input_)
-        weight = all_gather_weight_2d.apply(self.weight, -1, self.summa_dim, ParallelMode.PARALLEL_2D_COL)
        weight = all_gather_tensor_2d(self.weight, -1, ParallelMode.PARALLEL_2D_COL)
        output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)

        return output
@LAYERS.register_module
class VocabParallelEmbedding2D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.num_embeddings_per_partition = divide(self.num_embeddings, self.summa_dim)
self.embed_dim_per_partition = divide(self.embed_dim, self.summa_dim)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output = reduce_scatter_tensor_2d(output_parallel, 0, ParallelMode.PARALLEL_2D_COL)
return output
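A hedged usage sketch (not from this commit), assuming colossalai was launched with tensor parallel mode '2d' and size 4, so summa_dim = 2:

import torch
from colossalai.nn.layer.parallel_2d import VocabParallelEmbedding2D

embed = VocabParallelEmbedding2D(num_embeddings=50304, embedding_dim=1024)   # local weight: [50304/2, 1024/2]
tokens = torch.randint(0, 50304, (8, 512), device='cuda')
out = embed(tokens)   # partial embeddings are reduce-scattered along dim 0 over the column group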
@LAYERS.register_module
class Classifier2D(ParallelLayer):
    """
@@ -379,7 +457,7 @@ class Classifier2D(ParallelLayer):
                 num_classes: int,
                 weight: Parameter = None,
                 bias: bool = True,
-                 dtype: dtype = None,
                 dtype: torch.dtype = None,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
        super().__init__()
@@ -429,7 +507,101 @@ class Classifier2D(ParallelLayer):
    def forward(self, input_: Tensor) -> Tensor:
        out_shape = input_.shape[:-1] + (self.num_classes, )

-        return classifier_2d.apply(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank,
-                                   self.col_rank, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
-                                   self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
-                                   self.tensor_parallel_size)
        return classifier_2d(input_, self.weight, self.bias, self.summa_dim, out_shape, self.row_rank, self.col_rank,
                             ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, self.data_parallel_rank,
                             self.pipeline_parallel_rank, self.pipeline_parallel_size, self.tensor_parallel_size)
@LAYERS.register_module
class VocabParallelClassifier2D(ParallelLayer):
"""
Vocab parallel classifier layer for 2D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
# parallel setting
assert_summa_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
self.summa_dim = get_summa_dim_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.summa_dim)
self.output_size_per_partition = divide(num_classes, self.summa_dim)
# create weight, shape: [k/q, h/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.output_size_per_partition, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(divide(self.num_classes, self.summa_dim**2), **factory_kwargs))
else:
self.bias = None
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.summa_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.summa_dim**2)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/q, n/q, k/q]
# output: [m/q, n/q, h/q]
out_shape = x.shape[:-1] + (self.output_size_per_partition, )
output = Matmul_ABT_2D.apply(x, self.weight, self.summa_dim, out_shape, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
if self.bias is not None:
output = add_bias_2d(output, self.bias, self.output_size_per_partition, self.row_rank, self.col_rank,
ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, False,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return output
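A hedged usage sketch under the same assumed 2D setup (q = summa_dim = 2), with `hidden_states` as a placeholder activation shard; the layer keeps both its input features and its class dimension partitioned:

head = VocabParallelClassifier2D(in_features=1024, num_classes=50304)   # local weight: [50304/2, 1024/2]
logits = head(hidden_states)   # hidden_states: [..., 1024/q] shard in, logits: [..., 50304/q] shard out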
from ._operation import reduce_by_batch_2p5d, split_tensor_2p5d
-from .layers import Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D
from .layers import (Classifier2p5D, Embedding2p5D, LayerNorm2p5D, Linear2p5D, PatchEmbedding2p5D,
                     VocabParallelClassifier2p5D, VocabParallelEmbedding2p5D)

__all__ = [
    'split_tensor_2p5d', 'reduce_by_batch_2p5d', 'Linear2p5D', 'LayerNorm2p5D', 'Classifier2p5D', 'PatchEmbedding2p5D',
-    'Embedding2p5D'
    'Embedding2p5D', 'VocabParallelClassifier2p5D', 'VocabParallelEmbedding2p5D'
]
@@ -22,42 +22,7 @@ def get_parallel_rank(parallel_mode: ParallelMode):
    return gpc.get_local_rank(parallel_mode)

-def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
-    return torch.chunk(input_, gpc.get_world_size(ParallelMode.PARALLEL_2P5D_COL),
-                       dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()

-class classifier_2p5d(torch.autograd.Function):
class _Classifier2p5D(torch.autograd.Function):
"""
Classifier
:param a: matrix :math:`A`
:type a: torch.tensor
:param b: matrix :math:`B`
:type b: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
    @staticmethod
    @custom_fwd(cast_inputs=torch.float16)
    def forward(
@@ -122,12 +87,54 @@ class classifier_2p5d(torch.autograd.Function):
        B_grad = reduce_scatter(B_grad, -1, ctx.col_parallel_mode)
        B_grad = B_grad.reshape(ctx.B_shape)

-        bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
-        bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
        if ctx.use_bias:
            bias_grad = torch.sum(output_grad, dim=tuple(range(output_grad.ndim - 1)))
            bias_grad = all_reduce(bias_grad, ctx.col_parallel_mode)
        else:
            bias_grad = None

        return A_grad, B_grad, bias_grad, None, None, None, None, None, None, None, None, None, None
def classifier_2p5d(A: Tensor, B: Tensor, bias, tesseract_dim: int, out_shape: Tuple[int,
...], row_rank: int, col_rank: int,
row_parallel_mode: ParallelMode, col_parallel_mode: ParallelMode, data_parallel_rank: int,
pipeline_parallel_rank: int, pipeline_parallel_size: int, tensor_parallel_size: int) -> Tensor:
"""
Classifier
:param A: matrix :math:`A`
:type A: torch.tensor
:param B: matrix :math:`B`
:type B: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
:type tesseract_dim: int
:param out_shape: shape of output tensor
:type out_shape: tuple
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Classifier2p5D.apply(A, B, bias, tesseract_dim, out_shape, row_rank, col_rank, row_parallel_mode,
col_parallel_mode, data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size,
tensor_parallel_size)
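For orientation, a hedged usage sketch of the new functional wrapper; every variable below is a placeholder for the corresponding layer attribute, and the call simply mirrors how Classifier2p5D.forward later in this diff drives the op.

# Hedged sketch -- `x`, `weight`, `bias` and the rank/size values are hypothetical
# placeholders for the attributes a 2.5D classifier layer would hold.
out_shape = x.shape[:-1] + (num_classes, )
logits = classifier_2p5d(x, weight, bias, tesseract_dim, out_shape, row_rank, col_rank,
                         ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
                         data_parallel_rank, pipeline_parallel_rank,
                         pipeline_parallel_size, tensor_parallel_size)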
class Matmul_AB_2p5D(torch.autograd.Function): class Matmul_AB_2p5D(torch.autograd.Function):
""" """
Matrix multiplication for :math:`C = AB` Matrix multiplication for :math:`C = AB`
...@@ -522,37 +529,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function): ...@@ -522,37 +529,7 @@ class Matmul_ATB_2p5D(torch.autograd.Function):
return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None return A_grad, B_grad, None, None, None, None, None, None, None, None, None, None, None, None, None
class Add_Bias_2p5D(torch.autograd.Function): class _Add_Bias_2p5D(torch.autograd.Function):
"""
Matrix add bias: :math:`C = A + b`
:param input: matrix :math:`A`
:type input: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: output size in each partition
:type output_size_per_partition: int
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism
:type tesseract_dim: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, def forward(ctx: Any, input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int,
...@@ -621,7 +598,46 @@ class Add_Bias_2p5D(torch.autograd.Function): ...@@ -621,7 +598,46 @@ class Add_Bias_2p5D(torch.autograd.Function):
return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None return output_grad, reduce_tmp, None, None, None, None, None, None, None, None, None, None, None, None, None, None, None
class layernorm_2p5d(torch.autograd.Function): def add_bias_2p5d(input: Tensor, bias: Tensor, output_size_per_partition: int, tesseract_dim: int, row_rank: int,
col_rank: int, dep_rank: int, col_parallel_mode: ParallelMode, skip_bias_add: bool,
data_parallel_rank: int, pipeline_parallel_rank: int, pipeline_parallel_size: int,
tensor_parallel_size: int) -> Tensor:
"""
Matrix add bias: :math:`C = A + b`
:param input: matrix :math:`A`
:type input: torch.tensor
:param bias: matrix :math:`b`
:type bias: torch.tensor
:param output_size_per_partition: output size in each partition
:type output_size_per_partition: int
:param tesseract_dim: dimension of TESSERACT for 2.5D parallelism
:type tesseract_dim: int
:param row_rank: the rank of row
:type row_rank: int
:param col_rank: the rank of column
:type col_rank: int
:param dep_rank: the rank of depth
:type dep_rank: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param skip_bias_add: If set to ``True``, it will skip bias add for linear layer, which is preserved for kernel fusion
:type skip_bias_add: bool
:param data_parallel_rank: data parallel rank
:type data_parallel_rank: int
:param pipeline_parallel_rank: pipeline parallel rank
:type pipeline_parallel_rank: int
:param pipeline_parallel_size: pipeline parallel size
:type pipeline_parallel_size: int
:param tensor_parallel_size: tensor parallel size
:type tensor_parallel_size: int
"""
return _Add_Bias_2p5D.apply(input, bias, output_size_per_partition, tesseract_dim, row_rank, col_rank, dep_rank,
col_parallel_mode, skip_bias_add, data_parallel_rank, pipeline_parallel_rank,
pipeline_parallel_size, tensor_parallel_size)
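A hedged sketch of the two call patterns, mirroring Linear2p5D.forward later in this diff: with skip_bias_add=True the broadcast bias is returned on its own so a fused kernel can add it later, otherwise the bias is added to the given output.

# Hypothetical placeholders; see Linear2p5D.forward below for the real call sites.
if skip_bias_add:
    bias_out = add_bias_2p5d(None, bias, hidden_size_per_partition, tesseract_dim,
                             row_rank, col_rank, dep_rank, ParallelMode.PARALLEL_2P5D_COL,
                             True, data_parallel_rank, pipeline_parallel_rank,
                             pipeline_parallel_size, tensor_parallel_size)
    result = (output, bias_out)            # caller fuses the bias add elsewhere
else:
    result = add_bias_2p5d(output, bias, hidden_size_per_partition, tesseract_dim,
                           row_rank, col_rank, dep_rank, ParallelMode.PARALLEL_2P5D_COL,
                           False, data_parallel_rank, pipeline_parallel_rank,
                           pipeline_parallel_size, tensor_parallel_size)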
class _Layernorm2p5D(torch.autograd.Function):
""" """
Layernorm Layernorm
...@@ -671,25 +687,31 @@ class layernorm_2p5d(torch.autograd.Function): ...@@ -671,25 +687,31 @@ class layernorm_2p5d(torch.autograd.Function):
return input_grad, None, None, None, None, None, None return input_grad, None, None, None, None, None, None
class all_gather_weight_2p5d(torch.autograd.Function): def layernorm_2p5d(input: Tensor, E_x: Tensor, Var_x: Tensor, hidden_size: int,
row_parallel_mode: ParallelMode) -> Tensor:
""" """
all gather the weight of 2.5D parallelism Layernorm
:param inputs: input matrix :param input: input matrix
:type inputs: torch.tensor :type input: torch.tensor
:param dim: dimension of all gather :param E_x: mean
:type dim: int :type E_x: torch.tensor
:param tesseract_dim: dimension of TESSERACT fo 2.5D parallelism :param Var_x: variance
:type tesseract_dim: int :type Var_x: torch.tensor
:param col_parallel_mode: column parallel mode :param hidden_size: hidden size
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type hidden_size: int
:param row_parallel_mode: row parallel mode
:type row_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
return _Layernorm2p5D.apply(input, E_x, Var_x, hidden_size, row_parallel_mode)
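A rough, hedged sketch of how the row-wise statistics could be assembled before this call; the real layer (LayerNorm2p5D further down in this diff) stores 1/sqrt(Var + eps) in Var_x and then applies gamma/beta through add_bias_2p5d and torch.addcmul.

# Hedged illustration only: `x` is the local [..., h/q] shard, `hidden_size` the full
# hidden dimension and `epsilon` a small constant; all_reduce sums across the row group.
E_x = all_reduce(torch.sum(x, dim=-1, keepdim=True), ParallelMode.PARALLEL_2P5D_ROW) / hidden_size
Var_x = all_reduce(torch.sum(x * x, dim=-1, keepdim=True), ParallelMode.PARALLEL_2P5D_ROW) / hidden_size
Var_x = 1.0 / torch.sqrt(Var_x - E_x * E_x + epsilon)
y = layernorm_2p5d(x, E_x, Var_x, hidden_size, ParallelMode.PARALLEL_2P5D_ROW)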
class _AllGatherTensor2p5D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx: Any, inputs: Tensor, dim: int, tesseract_dim: int, col_parallel_mode: ParallelMode) -> Tensor: def forward(ctx: Any, inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
ctx.dim = dim ctx.dim = dim
ctx.tesseract_dim = tesseract_dim ctx.col_parallel_mode = col_parallel_mode
ctx.row_rank = gpc.get_local_rank(col_parallel_mode)
outputs = all_gather(inputs, dim, col_parallel_mode) outputs = all_gather(inputs, dim, col_parallel_mode)
return outputs return outputs
...@@ -697,8 +719,24 @@ class all_gather_weight_2p5d(torch.autograd.Function): ...@@ -697,8 +719,24 @@ class all_gather_weight_2p5d(torch.autograd.Function):
@staticmethod @staticmethod
@custom_bwd @custom_bwd
def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]: def backward(ctx: Any, output_grad: Tensor) -> Tuple[Tensor, ...]:
grad = output_grad.chunk(ctx.tesseract_dim, dim=ctx.dim)[ctx.row_rank] grad = reduce_scatter(output_grad, ctx.dim, ctx.col_parallel_mode)
return grad.contiguous(), None, None, None return grad.contiguous(), None, None
def all_gather_tensor_2p5d(inputs: Tensor, dim: int, col_parallel_mode: ParallelMode) -> Tensor:
"""
All-gather a tensor in 2.5D parallelism.
:param inputs: input matrix
:type inputs: torch.tensor
:param dim: dimension along which to all-gather
:type dim: int
:param col_parallel_mode: column parallel mode
:type col_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
return _AllGatherTensor2p5D.apply(inputs, dim, col_parallel_mode)
class SplitFirst(torch.autograd.Function): class SplitFirst(torch.autograd.Function):
...@@ -737,10 +775,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -737,10 +775,10 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
:param input_: Input tensor :param input_: Input tensor
:param dim: Specified dimension in which to split :param dim: Specified dimension in which to split
:type input_: torch.Tensor :type input_: torch.Tensor
:type dim: int, optional :type dim: int, optional
:return output: Split tensor :return output: Split tensor
:rtype output: torch.Tensor :rtype output: torch.Tensor
""" """
...@@ -750,9 +788,49 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor: ...@@ -750,9 +788,49 @@ def split_tensor_2p5d(input_: Tensor, dim: int = 0) -> Tensor:
dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous() dim=dim)[gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)].contiguous()
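A toy local view of the split (hypothetical sizes): with a column-parallel world size of 2, an [8, 1024] batch is chunked along dim 0 and each column rank keeps a contiguous [4, 1024] shard.

# rank 0 keeps rows 0..3, rank 1 keeps rows 4..7 (assuming 2 column-parallel ranks)
x_local = split_tensor_2p5d(x_global, dim=0)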
class reduce_by_batch_2p5d(torch.autograd.Function): class _ReduceTensor2p5D(torch.autograd.Function):
"""All-reduce the input from the model parallel region. @staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_2p5d(input_: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the input.
:param input_: input tensor
:param parallel_mode: parallel mode
"""
return _ReduceTensor2p5D.apply(input_, parallel_mode)
class _ReduceScatterTensor2p5D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return all_gather(output_grad, ctx.dim, ctx.parallel_mode), None, None
def reduce_scatter_tensor_2p5d(input_: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""
Reduce-scatter the input.
:param input_: input tensor
:param dim: dimension along which to reduce-scatter
:param parallel_mode: parallel mode
""" """
return _ReduceScatterTensor2p5D.apply(input_, dim, parallel_mode)
class _ReduceByBatch2p5D(torch.autograd.Function):
@staticmethod @staticmethod
def symbolic(graph, input_, reduce_mean: bool = False): def symbolic(graph, input_, reduce_mean: bool = False):
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
...@@ -764,12 +842,6 @@ class reduce_by_batch_2p5d(torch.autograd.Function): ...@@ -764,12 +842,6 @@ class reduce_by_batch_2p5d(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_, reduce_mean: bool = False): def forward(ctx, input_, reduce_mean: bool = False):
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by column parallel size, default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL) output = all_reduce(input_, ParallelMode.PARALLEL_2P5D_COL)
ctx.reduce_mean = reduce_mean ctx.reduce_mean = reduce_mean
if reduce_mean: if reduce_mean:
...@@ -785,3 +857,15 @@ class reduce_by_batch_2p5d(torch.autograd.Function): ...@@ -785,3 +857,15 @@ class reduce_by_batch_2p5d(torch.autograd.Function):
return output_grad / ctx.reduce_size, None return output_grad / ctx.reduce_size, None
else: else:
return output_grad, None return output_grad, None
def reduce_by_batch_2p5d(input_, reduce_mean: bool = False) -> Tensor:
"""
All-reduce the input from the model parallel region.
:param input_: input matrix
:type input_: torch.tensor
:param reduce_mean: If set to ``True``, it will divide the output by the column parallel size, defaults to False
:type reduce_mean: bool, optional
"""
return _ReduceByBatch2p5D.apply(input_, reduce_mean)
\ No newline at end of file
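A hedged usage sketch: each column rank holds the loss of its own batch shard, and reduce_mean=True divides the all-reduced sum by the column-parallel world size so every rank ends up with the same averaged loss.

# `loss` is a hypothetical scalar tensor computed on this rank's batch shard.
loss = reduce_by_batch_2p5d(loss, reduce_mean=True)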
import os
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
def get_tesseract_dim_dep_from_env(): def get_tesseract_dim_dep_from_env():
try: try:
tesseract_dim = int(os.environ['TESSERACT_DIM']) tesseract_dim = env.tesseract_dim
tesseract_dep = int(os.environ['TESSERACT_DEP']) tesseract_dep = env.tesseract_dep
assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero' assert tesseract_dim > 0, 'TESSERACT_DIM must be larger than zero'
assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero' assert tesseract_dep > 0, 'TESSERACT_DEP must be larger than zero'
return tesseract_dim, tesseract_dep return tesseract_dim, tesseract_dep
......
...@@ -7,16 +7,18 @@ import torch.nn.functional as F ...@@ -7,16 +7,18 @@ import torch.nn.functional as F
from colossalai.communication import broadcast from colossalai.communication import broadcast
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from ..base_layer import ParallelLayer from ..base_layer import ParallelLayer
from ..utils import (divide, set_tensor_parallel_attribute_by_partition, to_2tuple) from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
from ._operation import (Add_Bias_2p5D, Matmul_AB_2p5D, all_gather_weight_2p5d, classifier_2p5d, layernorm_2p5d) from ._operation import (add_bias_2p5d, Matmul_AB_2p5D, Matmul_ABT_2p5D, all_gather_tensor_2p5d, classifier_2p5d,
from ._utils import (assert_tesseract_initialization, get_tesseract_dim_dep_from_env) layernorm_2p5d, reduce_scatter_tensor_2p5d, split_tensor_2p5d)
from ._utils import assert_tesseract_initialization, get_tesseract_dim_dep_from_env
@LAYERS.register_module @LAYERS.register_module
...@@ -41,7 +43,7 @@ class Linear2p5D(ParallelLayer): ...@@ -41,7 +43,7 @@ class Linear2p5D(ParallelLayer):
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
skip_bias_add: bool = False, skip_bias_add: bool = False,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
...@@ -112,17 +114,16 @@ class Linear2p5D(ParallelLayer): ...@@ -112,17 +114,16 @@ class Linear2p5D(ParallelLayer):
if self.bias is not None: if self.bias is not None:
if self.skip_bias_add: if self.skip_bias_add:
bias = Add_Bias_2p5D.apply(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, bias = add_bias_2p5d(None, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
True, self.data_parallel_rank, self.pipeline_parallel_rank, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.pipeline_parallel_size, self.tensor_parallel_size) self.tensor_parallel_size)
return output, bias return output, bias
else: else:
output = Add_Bias_2p5D.apply(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim,
self.row_rank, self.col_rank, self.dep_rank, self.row_rank, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL,
ParallelMode.PARALLEL_2P5D_COL, False, self.data_parallel_rank, False, self.data_parallel_rank, self.pipeline_parallel_rank,
self.pipeline_parallel_rank, self.pipeline_parallel_size, self.pipeline_parallel_size, self.tensor_parallel_size)
self.tensor_parallel_size)
return output return output
else: else:
return output return output
...@@ -187,15 +188,15 @@ class LayerNorm2p5D(ParallelLayer): ...@@ -187,15 +188,15 @@ class LayerNorm2p5D(ParallelLayer):
# this time 1/sqrt(Var_x + epsilon) # this time 1/sqrt(Var_x + epsilon)
Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon) Var_x = 1.0 / torch.sqrt(Var_x + self.variance_epsilon)
output = layernorm_2p5d.apply(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW) output = layernorm_2p5d(x, E_x, Var_x, self.normalized_shape, ParallelMode.PARALLEL_2P5D_ROW)
bias = Add_Bias_2p5D.apply(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank, bias = add_bias_2p5d(None, self.beta, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
scale = Add_Bias_2p5D.apply(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank, scale = add_bias_2p5d(None, self.gamma, self.partitioned_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True, self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, True,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
output = torch.addcmul(bias, scale, output) output = torch.addcmul(bias, scale, output)
return output return output
...@@ -229,8 +230,8 @@ class PatchEmbedding2p5D(ParallelLayer): ...@@ -229,8 +230,8 @@ class PatchEmbedding2p5D(ParallelLayer):
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
...@@ -280,19 +281,21 @@ class PatchEmbedding2p5D(ParallelLayer): ...@@ -280,19 +281,21 @@ class PatchEmbedding2p5D(ParallelLayer):
position_embed_initializer(self.pos_embed) position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_2p5d(input_, 0)
B, C, H, W = input_.shape B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \ assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
weight = all_gather_weight_2p5d.apply(self.weight, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) weight = all_gather_tensor_2p5d(self.weight, 0, ParallelMode.PARALLEL_2P5D_COL)
bias = all_gather_weight_2p5d.apply(self.bias, 0, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) bias = all_gather_tensor_2p5d(self.bias, 0, ParallelMode.PARALLEL_2P5D_COL)
output = F.conv2d(input_, weight, bias, stride=self.patch_size) output = F.conv2d(input_, weight, bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = all_gather_weight_2p5d.apply(self.cls_token, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) cls_token = all_gather_tensor_2p5d(self.cls_token, -1, ParallelMode.PARALLEL_2P5D_COL)
pos_embed = all_gather_weight_2p5d.apply(self.pos_embed, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) pos_embed = all_gather_tensor_2p5d(self.pos_embed, -1, ParallelMode.PARALLEL_2P5D_COL)
cls_token = cls_token.expand(output.shape[0], -1, -1) cls_token = cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1) output = torch.cat((cls_token, output), dim=1)
output = output + pos_embed output = output + pos_embed
...@@ -322,7 +325,7 @@ class Embedding2p5D(ParallelLayer): ...@@ -322,7 +325,7 @@ class Embedding2p5D(ParallelLayer):
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
...@@ -359,13 +362,95 @@ class Embedding2p5D(ParallelLayer): ...@@ -359,13 +362,95 @@ class Embedding2p5D(ParallelLayer):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = all_gather_weight_2p5d.apply(self.weight, -1, self.tesseract_dim, ParallelMode.PARALLEL_2P5D_COL) input_ = split_tensor_2p5d(input_, 0)
weight = all_gather_tensor_2p5d(self.weight, -1, ParallelMode.PARALLEL_2P5D_COL)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
@LAYERS.register_module
class VocabParallelEmbedding2p5D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
assert_tesseract_initialization()
self.tesseract_dim, self.tesseract_dep = get_tesseract_dim_dep_from_env()
self.num_embeddings_per_partition = divide(self.num_embeddings, self.tesseract_dim)
self.embed_dim_per_partition = divide(self.embed_dim, self.tesseract_dim)
tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
# Build the mask.
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
# Mask the input.
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
output_parallel = F.embedding(masked_input, self.weight, self.padding_idx, *self.embed_args,
**self.embed_kwargs)
# Mask the output embedding.
output_parallel[input_mask, :] = 0.
# Reduce-scatter the partial embeddings across the column parallel group.
output = reduce_scatter_tensor_2p5d(output_parallel, 0, ParallelMode.PARALLEL_2P5D_COL)
return output
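The masking logic above can be checked in isolation; a toy single-process sketch, assuming this rank owns vocabulary ids 4..7 (vocab_start_index=4, vocab_end_index=8).

# Hypothetical local shard: 4 rows of an 8-dim embedding table.
weight_shard = torch.randn(4, 8)
ids = torch.tensor([[1, 5, 7, 9]])
mask = (ids < 4) | (ids >= 8)          # ids 1 and 9 live on other ranks
local_ids = ids.clone() - 4
local_ids[mask] = 0                    # any valid local index is fine here
partial = F.embedding(local_ids, weight_shard)
partial[mask, :] = 0.                  # rows another rank will contribute are zeroed
# the reduce-scatter above then sums these partial embeddings across the column group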
@LAYERS.register_module @LAYERS.register_module
class Classifier2p5D(ParallelLayer): class Classifier2p5D(ParallelLayer):
""" """
...@@ -391,7 +476,7 @@ class Classifier2p5D(ParallelLayer): ...@@ -391,7 +476,7 @@ class Classifier2p5D(ParallelLayer):
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
...@@ -442,7 +527,114 @@ class Classifier2p5D(ParallelLayer): ...@@ -442,7 +527,114 @@ class Classifier2p5D(ParallelLayer):
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
out_shape = input_.shape[:-1] + (self.num_classes, ) out_shape = input_.shape[:-1] + (self.num_classes, )
return classifier_2p5d.apply(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank, return classifier_2p5d(input_, self.weight, self.bias, self.tesseract_dim, out_shape, self.row_rank,
self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, self.col_rank, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size, self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size) self.tensor_parallel_size)
@LAYERS.register_module
class VocabParallelClassifier2p5D(ParallelLayer):
"""
Vocab parallel classifier layer for 2.5D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
# parallel setting
assert_tesseract_initialization()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_COL)
self.col_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_ROW)
self.dep_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2P5D_DEP)
self.tesseract_dim, _ = get_tesseract_dim_dep_from_env()
# partitioning dimension
self.input_size_per_partition = divide(in_features, self.tesseract_dim)
self.hidden_size_per_partition = divide(num_classes, self.tesseract_dim)
# create weight, shape: [h/q, k/q]
factory_kwargs = {'device': get_current_device(), 'dtype': dtype}
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.hidden_size_per_partition, self.input_size_per_partition, **factory_kwargs))
self.has_weight = True
# create bias, shape: [h/q]
if bias:
self.bias = Parameter(torch.empty(self.hidden_size_per_partition, **factory_kwargs))
else:
self.bias = None
# initialize parameters
with seed(ParallelMode.TENSOR):
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.tesseract_dim**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.tesseract_dim)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, x: Tensor) -> Tensor:
# input: [m/dq, n/q, k/q]
# output: [m/dq, n/q, h/q]
out_shape = x.shape[:-1] + (self.hidden_size_per_partition, )
output = Matmul_ABT_2p5D.apply(
x,
self.weight,
self.tesseract_dim,
out_shape,
self.row_rank,
self.col_rank,
self.dep_rank,
ParallelMode.PARALLEL_2P5D_ROW,
ParallelMode.PARALLEL_2P5D_COL,
self.data_parallel_rank,
self.pipeline_parallel_rank,
self.pipeline_parallel_size,
self.tensor_parallel_size,
)
if self.bias is not None:
output = add_bias_2p5d(output, self.bias, self.hidden_size_per_partition, self.tesseract_dim, self.row_rank,
self.col_rank, self.dep_rank, ParallelMode.PARALLEL_2P5D_COL, False,
self.data_parallel_rank, self.pipeline_parallel_rank, self.pipeline_parallel_size,
self.tensor_parallel_size)
return output
from ._operation import reduce_by_batch_3d, split_tensor_3d from ._operation import reduce_by_batch_3d, split_batch_3d, split_tensor_3d
from .layers import Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D from .layers import (Classifier3D, Embedding3D, LayerNorm3D, Linear3D, PatchEmbedding3D, VocabParallelClassifier3D,
VocabParallelEmbedding3D)
__all__ = [ __all__ = [
'reduce_by_batch_3d', 'split_tensor_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D', 'Classifier3D', 'Embedding3D' 'reduce_by_batch_3d', 'split_tensor_3d', 'split_batch_3d', 'Linear3D', 'LayerNorm3D', 'PatchEmbedding3D',
'Classifier3D', 'Embedding3D', 'VocabParallelEmbedding3D', 'VocabParallelClassifier3D'
] ]
...@@ -4,36 +4,20 @@ ...@@ -4,36 +4,20 @@
from typing import Optional, Tuple from typing import Optional, Tuple
import torch import torch
from colossalai.communication import all_gather, all_reduce, reduce_scatter, broadcast, reduce from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
from colossalai.context import parallel_mode
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from torch import Tensor from torch import Tensor
from torch.cuda.amp import custom_bwd, custom_fwd from torch.cuda.amp import custom_bwd, custom_fwd
from ._utils import get_parallel_mode_from_env
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.nn.layer.base_layer import ParallelLayer
class linear_3d(torch.autograd.Function):
"""
Linear layer for 3D parallelism
:param input_: matrix of input class _Linear3D(torch.autograd.Function):
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor, optional
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
"""
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, def forward(ctx,
...@@ -87,6 +71,8 @@ class linear_3d(torch.autograd.Function): ...@@ -87,6 +71,8 @@ class linear_3d(torch.autograd.Function):
bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1])) bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op) async_ops.append(op)
else:
bias_grad = None
for op in async_ops: for op in async_ops:
if op is not None: if op is not None:
...@@ -95,9 +81,17 @@ class linear_3d(torch.autograd.Function): ...@@ -95,9 +81,17 @@ class linear_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class classifier_3d(torch.autograd.Function): def linear_3d(input_: Tensor,
weight: Tensor,
bias: Optional[Tensor],
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
output_parallel_mode: ParallelMode,
input_dim: int = 0,
weight_dim: int = -1,
output_dim: int = 0) -> Tensor:
""" """
Classifier Linear layer for 3D parallelism
:param input_: matrix of input :param input_: matrix of input
:type input_: torch.tensor :type input_: torch.tensor
...@@ -111,7 +105,19 @@ class classifier_3d(torch.autograd.Function): ...@@ -111,7 +105,19 @@ class classifier_3d(torch.autograd.Function):
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode :param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param input_dim: dimension of input, defaults to 0
:type input_dim: int, optional
:param weight_dim: dimension of weight, defaults to -1
:type weight_dim: int, optional
:param output_dim: dimension of output, defaults to 0
:type output_dim: int, optional
""" """
return _Linear3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode,
input_dim, weight_dim, output_dim)
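A hedged one-line call sketch, mirroring Linear3D.forward later in this diff; the three parallel modes are the ones obtained from get_parallel_mode_from_env and get_last_group, and all other names are placeholders.

out = linear_3d(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)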
class _Classifier3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
...@@ -156,6 +162,8 @@ class classifier_3d(torch.autograd.Function): ...@@ -156,6 +162,8 @@ class classifier_3d(torch.autograd.Function):
bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode) bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True) bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
async_ops.append(op) async_ops.append(op)
else:
bias_grad = None
input_grad = torch.matmul(output_grad, weight) input_grad = torch.matmul(output_grad, weight)
...@@ -166,23 +174,17 @@ class classifier_3d(torch.autograd.Function): ...@@ -166,23 +174,17 @@ class classifier_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None, None
class layernorm_3d(torch.autograd.Function): def classifier_3d(input_: Tensor, weight: Tensor, bias: Optional[Tensor], input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
""" """
Layernorm 3D parallel classifier
:param input_: input maxtrix :param input_: matrix of input
:type input_: torch.tensor :type input_: torch.tensor
:param weight: matrix of weight :param weight: matrix of weight
:type weight: torch.tensor :type weight: torch.tensor
:param bias: matrix of bias :param bias: matrix of bias
:type bias: torch.tensor :type bias: torch.tensor, optional
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode :param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode :param weight_parallel_mode: weight parallel mode
...@@ -190,6 +192,11 @@ class layernorm_3d(torch.autograd.Function): ...@@ -190,6 +192,11 @@ class layernorm_3d(torch.autograd.Function):
:param output_parallel_mode: output parallel mode :param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
return _Classifier3D.apply(input_, weight, bias, input_parallel_mode, weight_parallel_mode, output_parallel_mode)
class _Layernorm3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float, def forward(ctx, input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
...@@ -236,27 +243,78 @@ class layernorm_3d(torch.autograd.Function): ...@@ -236,27 +243,78 @@ class layernorm_3d(torch.autograd.Function):
return input_grad, weight_grad, bias_grad, None, None, None, None, None return input_grad, weight_grad, bias_grad, None, None, None, None, None
def split_tensor_3d(input_: Tensor, def layernorm_3d(input_: Tensor, weight: Tensor, bias: Tensor, normalized_shape: int, eps: float,
dim: int = 0, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT, output_parallel_mode: ParallelMode) -> Tensor:
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor: """
"""Splits 3D tensor in specified dimension 3D parallel Layernorm
:param input_: input matrix
:type input_: torch.tensor
:param weight: matrix of weight
:type weight: torch.tensor
:param bias: matrix of bias
:type bias: torch.tensor
:param normalized_shape: input shape from an expected input
of size. :math:`[* \times \text{normalized_shape}[0] \times \text{normalized_shape}[1] \times \ldots \times \text{normalized_shape}[-1]]`
If a single integer is used, it is treated as a singleton list, and this module will
normalize over the last dimension which is expected to be of that specific size.
:type normalized_shape: int
:param eps: a value added to the denominator for numerical stability
:type eps: float
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param output_parallel_mode: output parallel mode
:type output_parallel_mode: colossalai.context.parallel_mode.ParallelMode
"""
return _Layernorm3D.apply(input_, weight, bias, normalized_shape, eps, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
def split_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
"""Splits 3D parallel tensor in specified dimension
:param tensor: Input tensor
:param dim: Specified dimension in which to split
:param parallel_mode: Parallel mode
:type tensor: torch.Tensor
:type dim: int
:type parallel_mode: colossalai.context.parallel_mode.ParallelMode
:return output: Split tensor
:rtype output: torch.Tensor
"""
if tensor.size(dim) <= 1:
return tensor
output = torch.chunk(tensor, gpc.get_world_size(parallel_mode),
dim=dim)[gpc.get_local_rank(parallel_mode)].contiguous()
return output
def split_batch_3d(input_: Tensor,
dim: int = 0,
input_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_INPUT,
weight_parallel_mode: ParallelMode = ParallelMode.PARALLEL_3D_WEIGHT) -> Tensor:
"""Splits 3D tensor in batch
:param input_: Input tensor :param input_: Input tensor
:param dim: Specified dimension in which to split :param dim: Specified dimension in which to split
:param input_parallel_mode: Input parallel mode :param input_parallel_mode: Input parallel mode
:param weight_parallel_mode: Weight parallel mode :param weight_parallel_mode: Weight parallel mode
:type input_: torch.Tensor :type input_: torch.Tensor
:type dim: int, optional :type dim: int, optional
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional :type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode, optional
:return output: Split tensor :return output: Split tensor
:rtype output: torch.Tensor :rtype output: torch.Tensor
""" """
if input_.size(dim) <= 1: if input_.size(dim) <= 1:
return input_ return input_
weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode), output = torch.chunk(input_, gpc.get_world_size(weight_parallel_mode),
dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous() dim=dim)[gpc.get_local_rank(weight_parallel_mode)].contiguous()
output = torch.chunk(output, gpc.get_world_size(input_parallel_mode), output = torch.chunk(output, gpc.get_world_size(input_parallel_mode),
...@@ -264,9 +322,77 @@ def split_tensor_3d(input_: Tensor, ...@@ -264,9 +322,77 @@ def split_tensor_3d(input_: Tensor,
return output return output
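A hedged toy view of the batch split (hypothetical sizes): with depth 2 the batch dimension is chunked once over the weight group and once over the input group, so each rank keeps 16 / (2 * 2) = 4 samples.

x = torch.randn(16, 196, 512)        # same global batch on every rank
x_local = split_batch_3d(x, dim=0)   # -> [4, 196, 512] per rank (assuming depth 2)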
class reduce_by_batch_3d(torch.autograd.Function): class _ReduceTensor3D(torch.autograd.Function):
"""All-reduce the input from the model parallel region.
@staticmethod
def forward(ctx, input_, parallel_mode):
return all_reduce(input_, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
return output_grad, None
def reduce_tensor_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the input.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
return _ReduceTensor3D.apply(tensor, parallel_mode)
class _ReduceGrad3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, parallel_mode):
ctx.parallel_mode = parallel_mode
return input_
@staticmethod
def backward(ctx, output_grad):
input_grad = all_reduce(output_grad, ctx.parallel_mode)
return input_grad, None
def reduce_grad_3d(tensor: Tensor, parallel_mode: ParallelMode) -> Tensor:
"""
All-reduce the gradient in backward pass.
:param tensor: Input tensor
:param parallel_mode: Parallel mode
"""
return _ReduceGrad3D.apply(tensor, parallel_mode)
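This is the usual identity-forward / all-reduce-backward helper; a one-line hedged sketch with placeholder names.

# `y` equals `x` in the forward pass; the gradient w.r.t. `x` is summed over the group.
y = reduce_grad_3d(x, weight_parallel_mode)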
class _ReduceScatterTensor3D(torch.autograd.Function):
@staticmethod
def forward(ctx, input_, dim, parallel_mode):
ctx.dim = dim
ctx.parallel_mode = parallel_mode
return reduce_scatter(input_, dim, parallel_mode)
@staticmethod
def backward(ctx, output_grad):
input_grad = all_gather(output_grad, ctx.dim, ctx.parallel_mode)
return input_grad, None, None
def reduce_scatter_tensor_3d(tensor: Tensor, dim: int, parallel_mode: ParallelMode) -> Tensor:
""" """
Reduce-scatter the input.
:param tensor: Input tensor
:param dim: Dimension to scatter
:param parallel_mode: Parallel mode
"""
return _ReduceScatterTensor3D.apply(tensor, dim, parallel_mode)
class _ReduceByBatch3D(torch.autograd.Function):
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float32) @custom_fwd(cast_inputs=torch.float32)
def forward(ctx, def forward(ctx,
...@@ -274,16 +400,6 @@ class reduce_by_batch_3d(torch.autograd.Function): ...@@ -274,16 +400,6 @@ class reduce_by_batch_3d(torch.autograd.Function):
input_parallel_mode: ParallelMode, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor: reduce_mean: bool = False) -> Tensor:
"""
:param input_: input maxtrix
:type input_: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), default to False
:type reduce_mean: int, optional
"""
output = all_reduce(input_, input_parallel_mode) output = all_reduce(input_, input_parallel_mode)
output = all_reduce(output, weight_parallel_mode) output = all_reduce(output, weight_parallel_mode)
ctx.reduce_mean = reduce_mean ctx.reduce_mean = reduce_mean
...@@ -302,7 +418,26 @@ class reduce_by_batch_3d(torch.autograd.Function): ...@@ -302,7 +418,26 @@ class reduce_by_batch_3d(torch.autograd.Function):
return output_grad, None, None, None return output_grad, None, None, None
class broadcast_weight_3d_from_diagonal(torch.autograd.Function): def reduce_by_batch_3d(tensor: Tensor,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
reduce_mean: bool = False) -> Tensor:
"""
All-reduce the input from the model parallel region.
:param tensor: input tensor
:type tensor: torch.tensor
:param input_parallel_mode: input parallel mode
:type input_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param weight_parallel_mode: weight parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
:param reduce_mean: If set to ``True``, it will divide the output by (input parallel size * weight parallel size), defaults to False
:type reduce_mean: bool, optional
"""
return _ReduceByBatch3D.apply(tensor, input_parallel_mode, weight_parallel_mode, reduce_mean)
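A hedged usage sketch: averaging a scalar training loss over both batch-splitting groups so every rank sees the same value (all names below are placeholders).

loss = reduce_by_batch_3d(loss, input_parallel_mode, weight_parallel_mode, reduce_mean=True)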
class _BroadcastWeight3D_FromDiagonal(torch.autograd.Function):
""" """
broadcast weight from diagonal broadcast weight from diagonal
...@@ -315,6 +450,7 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function): ...@@ -315,6 +450,7 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
:param weight_parallel_mode: output parallel mode :param weight_parallel_mode: output parallel mode
:type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode :type weight_parallel_mode: colossalai.context.parallel_mode.ParallelMode
""" """
@staticmethod @staticmethod
@custom_fwd(cast_inputs=torch.float16) @custom_fwd(cast_inputs=torch.float16)
def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode, def forward(ctx, input_: Tensor, input_parallel_mode: ParallelMode, weight_parallel_mode: ParallelMode,
...@@ -337,3 +473,9 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function): ...@@ -337,3 +473,9 @@ class broadcast_weight_3d_from_diagonal(torch.autograd.Function):
else: else:
input_grad = None input_grad = None
return input_grad, None, None, None return input_grad, None, None, None
def broadcast_weight_3d_from_diagonal(tensor: Tensor, input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode, output_parallel_mode: ParallelMode) -> Tensor:
return _BroadcastWeight3D_FromDiagonal.apply(tensor, input_parallel_mode, weight_parallel_mode,
output_parallel_mode)
#!/usr/bin/env python from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
# -*- encoding: utf-8 -*-
import os
from colossalai.constants import (DEPTH_3D, INPUT_GROUP_3D, OUTPUT_GROUP_3D,
WEIGHT_GROUP_3D)
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from torch import Tensor from torch import Tensor
def get_depth_from_env() -> int: def get_depth_from_env() -> int:
try: try:
depth = os.environ[DEPTH_3D] depth = env.depth_3d
depth = int(depth)
assert depth > 0, 'DEPTH must be greater than zero' assert depth > 0, 'DEPTH must be greater than zero'
return depth return depth
except KeyError as e: except KeyError as e:
raise EnvironmentError( raise EnvironmentError('DEPTH is not found in the current environment, '
'DEPTH is not found in the current environment, ' 'please make sure that you have used the correct process group initializer')
'please make sure that you have used the correct process group initializer'
)
def get_parallel_mode_from_env(group): def get_parallel_mode_from_env(group):
return getattr(ParallelMode, os.environ[group]) assert group in [INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D], \
f'{group} is not valid for 3D tensor parallelism.'
return getattr(env, group)
def get_last_group(a, b): def get_last_group(a, b):
...@@ -35,8 +29,7 @@ def get_last_group(a, b): ...@@ -35,8 +29,7 @@ def get_last_group(a, b):
ParallelMode.PARALLEL_3D_OUTPUT: 'C', ParallelMode.PARALLEL_3D_OUTPUT: 'C',
} }
res = chr( res = chr(ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
ord('A') + ord('B') + ord('C') - ord(mapping[a]) - ord(mapping[b]))
if res == 'A': if res == 'A':
return ParallelMode.PARALLEL_3D_INPUT return ParallelMode.PARALLEL_3D_INPUT
...@@ -47,8 +40,7 @@ def get_last_group(a, b): ...@@ -47,8 +40,7 @@ def get_last_group(a, b):
def swap_in_out_group(): def swap_in_out_group():
os.environ[INPUT_GROUP_3D], os.environ[OUTPUT_GROUP_3D] = \ env.input_group_3d, env.output_group_3d = env.output_group_3d, env.input_group_3d
os.environ[OUTPUT_GROUP_3D], os.environ[INPUT_GROUP_3D]
def dbg_check_shape(tensor: Tensor, shape: tuple): def dbg_check_shape(tensor: Tensor, shape: tuple):
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math import math
from typing import Callable from typing import Callable
...@@ -10,11 +8,12 @@ from colossalai.communication import all_reduce, broadcast ...@@ -10,11 +8,12 @@ from colossalai.communication import all_reduce, broadcast
from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
from colossalai.context import ParallelMode, seed from colossalai.context import ParallelMode, seed
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.nn.layer.base_layer import ParallelLayer from colossalai.nn.layer.base_layer import ParallelLayer
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch.nn import Parameter from torch.nn import Parameter
from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
...@@ -37,7 +36,8 @@ class LayerNorm3D(ParallelLayer): ...@@ -37,7 +36,8 @@ class LayerNorm3D(ParallelLayer):
:param dtype: The dtype of parameters, defaults to None :param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional :type dtype: torch.dtype, optional
""" """
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype: dtype = None):
def __init__(self, normalized_shape: int, eps: float = 1e-12, dtype=None):
super().__init__() super().__init__()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D) self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D) self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
...@@ -62,8 +62,8 @@ class LayerNorm3D(ParallelLayer): ...@@ -62,8 +62,8 @@ class LayerNorm3D(ParallelLayer):
init.ones_()(self.weight) init.ones_()(self.weight)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return layernorm_3d.apply(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon, return layernorm_3d(input_, self.weight, self.bias, self.normalized_shape, self.variance_epsilon,
self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode) self.input_parallel_mode, self.weight_parallel_mode, self.output_parallel_mode)
@LAYERS.register_module @LAYERS.register_module
...@@ -84,11 +84,12 @@ class Linear3D(ParallelLayer): ...@@ -84,11 +84,12 @@ class Linear3D(ParallelLayer):
:param bias_initializer: The initializer of bias, defaults to xavier uniform initializer :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
out_features: int, out_features: int,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
...@@ -136,8 +137,8 @@ class Linear3D(ParallelLayer): ...@@ -136,8 +137,8 @@ class Linear3D(ParallelLayer):
broadcast(self.bias, output_src_rank, self.output_parallel_mode) broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return linear_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, return linear_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode) self.output_parallel_mode)
@LAYERS.register_module @LAYERS.register_module
...@@ -160,12 +161,13 @@ class Classifier3D(ParallelLayer): ...@@ -160,12 +161,13 @@ class Classifier3D(ParallelLayer):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: Parameter = None, weight: Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
...@@ -214,8 +216,94 @@ class Classifier3D(ParallelLayer): ...@@ -214,8 +216,94 @@ class Classifier3D(ParallelLayer):
broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
return classifier_3d.apply(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode, return classifier_3d(input_, self.weight, self.bias, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode) self.output_parallel_mode)
@LAYERS.register_module
class VocabParallelClassifier3D(ParallelLayer):
"""
    Vocab parallel classifier layer for 3D parallelism
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
    :param weight: weight of the classifier, defaults to None
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to ``True``
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
    :param weight_initializer: The initializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
    :param bias_initializer: The initializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(num_classes, self.depth)
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = Parameter(
torch.empty(self.out_features_per_partition,
self.in_features_per_partition,
device=get_current_device(),
dtype=dtype))
self.has_weight = True
if bias:
self.bias = Parameter(torch.zeros(self.out_features_per_partition, device=get_current_device(),
dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
self._set_tensor_parallel_attributes()
swap_in_out_group()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self) -> None:
if self.has_weight:
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
if self.bias is not None:
set_tensor_parallel_attribute_by_partition(self.bias, self.depth)
def reset_parameters(self, weight_initializer, bias_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.in_features, self.num_classes
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
output_src_rank = gpc.get_ranks_in_group(self.output_parallel_mode)[0]
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, output_src_rank, self.output_parallel_mode)
def forward(self, input_: Tensor) -> Tensor:
return linear_3d(input_, self.weight.transpose(0, 1), self.bias, self.input_parallel_mode,
self.weight_parallel_mode, self.output_parallel_mode)
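For reference, a minimal usage sketch of the new VocabParallelClassifier3D (not part of this diff): it assumes colossalai has already been launched with a 3D tensor-parallel configuration of size 8 (cube depth 2), and the import path and tensor shapes below are illustrative assumptions rather than values taken from this commit.
# hedged sketch: requires an initialized 3D tensor-parallel context (depth 2 assumed)
import torch
from colossalai.nn.layer.parallel_3d import VocabParallelClassifier3D  # import path assumed
head = VocabParallelClassifier3D(in_features=1024, num_classes=50304, bias=True)
# each rank holds its local shard of the hidden dimension (1024 / depth = 512)
hidden_states = torch.randn(4, 256, 512, device='cuda')
logits = head(hidden_states)  # the vocab dimension stays partitioned across the output group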
@LAYERS.register_module @LAYERS.register_module
...@@ -242,13 +330,14 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -242,13 +330,14 @@ class PatchEmbedding3D(ParallelLayer):
:param position_embed_initializer: The intializer of position embedding, defaults to zero :param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional :type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
...@@ -284,8 +373,8 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -284,8 +373,8 @@ class PatchEmbedding3D(ParallelLayer):
set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth) set_tensor_parallel_attribute_by_partition(self.cls_token, self.depth)
set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth) set_tensor_parallel_attribute_by_partition(self.pos_embed, self.depth)
def _sync_grad_hook(self, grad) -> None: def _sync_grad_hook(self, grad) -> Tensor:
grad = all_reduce(grad, self.input_parallel_mode) grad = all_reduce(grad.clone(), self.input_parallel_mode)
grad = all_reduce(grad, self.weight_parallel_mode) grad = all_reduce(grad, self.weight_parallel_mode)
return grad return grad
...@@ -302,17 +391,19 @@ class PatchEmbedding3D(ParallelLayer): ...@@ -302,17 +391,19 @@ class PatchEmbedding3D(ParallelLayer):
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode) broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
broadcast(self.bias, weight_src_rank, self.weight_parallel_mode) broadcast(self.bias, weight_src_rank, self.weight_parallel_mode)
broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode) broadcast(self.pos_embed, weight_src_rank, self.weight_parallel_mode)
broadcast(self.weight, input_src_rank, self.input_parallel_mode)
broadcast(self.bias, input_src_rank, self.input_parallel_mode) broadcast(self.bias, input_src_rank, self.input_parallel_mode)
broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode) broadcast(self.pos_embed, input_src_rank, self.input_parallel_mode)
self.weight.register_hook(self._sync_grad_hook)
self.bias.register_hook(self._sync_grad_hook) self.bias.register_hook(self._sync_grad_hook)
self.cls_token.register_hook(self._sync_grad_hook) self.cls_token.register_hook(self._sync_grad_hook)
self.pos_embed.register_hook(self._sync_grad_hook) self.pos_embed.register_hook(self._sync_grad_hook)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
self.weight_parallel_mode, self.output_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
output = F.conv2d(input_, weight, self.bias, stride=self.patch_size) output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten: if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
...@@ -341,11 +432,12 @@ class Embedding3D(ParallelLayer): ...@@ -341,11 +432,12 @@ class Embedding3D(ParallelLayer):
:param args: Args used in F.embedding :param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding :param kwargs: Kwargs used in F.embedding
""" """
def __init__(self, def __init__(self,
num_embeddings: int, num_embeddings: int,
embedding_dim: int, embedding_dim: int,
padding_idx: int = None, padding_idx: int = None,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(), weight_initializer: Callable = init.normal_(),
*args, *args,
**kwargs): **kwargs):
...@@ -385,8 +477,95 @@ class Embedding3D(ParallelLayer): ...@@ -385,8 +477,95 @@ class Embedding3D(ParallelLayer):
self.weight[self.padding_idx].fill_(0) self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor: def forward(self, input_: Tensor) -> Tensor:
weight = broadcast_weight_3d_from_diagonal.apply(self.weight, self.input_parallel_mode, input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
self.weight_parallel_mode, self.output_parallel_mode) input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
weight = broadcast_weight_3d_from_diagonal(self.weight, self.input_parallel_mode, self.weight_parallel_mode,
self.output_parallel_mode)
output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) output = F.embedding(input_, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
return output return output
@LAYERS.register_module
class VocabParallelEmbedding3D(torch.nn.Module):
"""Embedding parallelized in the vocabulary dimension.
:param num_embeddings: number of embeddings
:type num_embeddings: int
:param embedding_dim: dimension of embedding
:type embedding_dim: int
:param padding_idx: index of padding, defaults to None
:type padding_idx: int, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
    :param weight_initializer: The initializer of weight, defaults to normal initializer
:type weight_initializer: typing.Callable, optional
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: torch.dtype = None,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()
self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
self.depth = get_depth_from_env()
self.input_parallel_mode = get_parallel_mode_from_env(INPUT_GROUP_3D)
self.weight_parallel_mode = get_parallel_mode_from_env(WEIGHT_GROUP_3D)
self.output_parallel_mode = get_last_group(self.input_parallel_mode, self.weight_parallel_mode)
self.num_embeddings_per_partition = divide(self.num_embeddings, self.depth)
self.embed_dim_per_partition = divide(self.embed_dim, self.depth)
vocab_parallel_rank = gpc.get_local_rank(self.input_parallel_mode)
self.vocab_start_index = vocab_parallel_rank * self.num_embeddings_per_partition
self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
self.weight = Parameter(
torch.empty((self.num_embeddings_per_partition, self.embed_dim_per_partition),
device=get_current_device(),
dtype=dtype))
self.reset_parameters(weight_initializer)
self._set_tensor_parallel_attributes()
env.vocab_parallel = True
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute_by_partition(self.weight, self.depth**2)
def reset_parameters(self, weight_initializer) -> None:
with seed(ParallelMode.TENSOR):
fan_in, fan_out = self.num_embeddings, self.embed_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()
weight_src_rank = gpc.get_ranks_in_group(self.weight_parallel_mode)[0]
broadcast(self.weight, weight_src_rank, self.weight_parallel_mode)
def _fill_padding_idx_with_zero(self) -> None:
if self.padding_idx is not None:
with torch.no_grad():
self.weight[self.padding_idx].fill_(0)
def forward(self, input_: Tensor) -> Tensor:
input_ = split_tensor_3d(input_, 0, self.weight_parallel_mode)
input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
masked_input = input_.clone() - self.vocab_start_index
masked_input[input_mask] = 0
weight = reduce_grad_3d(self.weight, self.weight_parallel_mode)
output_parallel = F.embedding(masked_input, weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output_parallel[input_mask, :] = 0.
output = reduce_scatter_tensor_3d(output_parallel, 0, self.input_parallel_mode)
return output
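A matching sketch for VocabParallelEmbedding3D (same assumptions as above: a launched 3D tensor-parallel context of depth 2; shapes and import path are illustrative). The full token-id batch is passed in, since forward splits it along the batch dimension over the weight and input parallel modes before masking out-of-range ids.
# hedged sketch: requires an initialized 3D tensor-parallel context (depth 2 assumed)
import torch
from colossalai.nn.layer.parallel_3d import VocabParallelEmbedding3D  # import path assumed
embed = VocabParallelEmbedding3D(num_embeddings=50304, embedding_dim=1024)
token_ids = torch.randint(0, 50304, (8, 256), device='cuda')  # full batch; split inside forward
hidden_states = embed(token_ids)  # embedded locally, then reduce-scattered over the input group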
...@@ -2,12 +2,12 @@ ...@@ -2,12 +2,12 @@
# -*- encoding: utf-8 -*- # -*- encoding: utf-8 -*-
import collections.abc import collections.abc
import os
from itertools import repeat from itertools import repeat
import numpy as np import numpy as np
import torch import torch
from colossalai.constants import (IS_TENSOR_PARALLEL, NUM_PARTITIONS, TENSOR_PARALLEL_MODE) from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint from colossalai.utils import checkpoint
from torch import Tensor, nn from torch import Tensor, nn
...@@ -38,7 +38,7 @@ class CheckpointModule(nn.Module): ...@@ -38,7 +38,7 @@ class CheckpointModule(nn.Module):
def divide(numerator, denominator): def divide(numerator, denominator):
"""Only allow exact division """Only allow exact division
:param numerator: Numerator of the division :param numerator: Numerator of the division
:param denominator: Denominator of the division :param denominator: Denominator of the division
""" """
...@@ -65,7 +65,7 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions): ...@@ -65,7 +65,7 @@ def set_tensor_parallel_attribute_by_partition(param, num_partitions):
def get_tensor_parallel_mode(): def get_tensor_parallel_mode():
return os.environ[TENSOR_PARALLEL_MODE] return env.mode
# From PyTorch internals # From PyTorch internals
......
...@@ -3,14 +3,14 @@ from typing import Callable ...@@ -3,14 +3,14 @@ from typing import Callable
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from colossalai.context import seed
from colossalai.nn import init as init from colossalai.nn import init as init
from colossalai.registry import LAYERS from colossalai.registry import LAYERS
from colossalai.utils import get_current_device from colossalai.utils.cuda import get_current_device
from torch import Tensor, dtype from torch import Tensor
from torch import nn as nn from torch import nn as nn
from ..utils import to_2tuple from ..utils import to_2tuple
from colossalai.context import seed
def drop_path(x, drop_prob: float = 0., training: bool = False): def drop_path(x, drop_prob: float = 0., training: bool = False):
...@@ -36,6 +36,7 @@ class DropPath(nn.Module): ...@@ -36,6 +36,7 @@ class DropPath(nn.Module):
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
""" """
def __init__(self, drop_prob=None): def __init__(self, drop_prob=None):
super(DropPath, self).__init__() super(DropPath, self).__init__()
self.drop_prob = drop_prob self.drop_prob = drop_prob
...@@ -47,6 +48,7 @@ class DropPath(nn.Module): ...@@ -47,6 +48,7 @@ class DropPath(nn.Module):
class WrappedDropout(nn.Module): class WrappedDropout(nn.Module):
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager. """Same as torch.nn.Dropout. But it is wrapped with the context of seed manager.
""" """
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None): def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__() super().__init__()
if p < 0 or p > 1: if p < 0 or p > 1:
...@@ -75,6 +77,7 @@ class WrappedDropPath(nn.Module): ...@@ -75,6 +77,7 @@ class WrappedDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager. Here, it is wrapped with the context of seed manager.
""" """
def __init__(self, p: float = 0., mode=None): def __init__(self, p: float = 0., mode=None):
super().__init__() super().__init__()
self.p = p self.p = p
...@@ -120,13 +123,14 @@ class VanillaPatchEmbedding(nn.Module): ...@@ -120,13 +123,14 @@ class VanillaPatchEmbedding(nn.Module):
:param position_embed_initializer: The intializer of position embedding, defaults to zero :param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional :type position_embed_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
img_size: int, img_size: int,
patch_size: int, patch_size: int,
in_chans: int, in_chans: int,
embed_size: int, embed_size: int,
dtype: dtype = None,
flatten: bool = True, flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1), bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()): position_embed_initializer: Callable = init.zeros_()):
...@@ -142,8 +146,9 @@ class VanillaPatchEmbedding(nn.Module): ...@@ -142,8 +146,9 @@ class VanillaPatchEmbedding(nn.Module):
self.weight = nn.Parameter( self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype)) torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype)) self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_size)) self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_size)) self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer) self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
...@@ -170,7 +175,7 @@ class VanillaPatchEmbedding(nn.Module): ...@@ -170,7 +175,7 @@ class VanillaPatchEmbedding(nn.Module):
@LAYERS.register_module @LAYERS.register_module
class VanillaClassifier(nn.Module): class VanillaClassifier(nn.Module):
""" """
Classifier for ViT Dense linear classifier
:param in_features: size of each input sample :param in_features: size of each input sample
:type in_features: int :type in_features: int
...@@ -187,12 +192,13 @@ class VanillaClassifier(nn.Module): ...@@ -187,12 +192,13 @@ class VanillaClassifier(nn.Module):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional :type bias_initializer: typing.Callable, optional
""" """
def __init__(self, def __init__(self,
in_features: int, in_features: int,
num_classes: int, num_classes: int,
weight: nn.Parameter = None, weight: nn.Parameter = None,
bias: bool = True, bias: bool = True,
dtype: dtype = None, dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)), weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)): bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__() super().__init__()
......
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn.layer.utils import get_tensor_parallel_mode
from torch import nn from torch import nn
from torch.nn.modules.loss import * from torch.nn.modules.loss import *
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from colossalai.nn.layer.utils import get_tensor_parallel_mode from .loss_1d import VocabParallelCrossEntropyLoss1D
from .loss_2d import CrossEntropyLoss2D from .loss_2d import CrossEntropyLoss2D, VocabParallelCrossEntropyLoss2D
from .loss_2p5d import CrossEntropyLoss2p5D from .loss_2p5d import CrossEntropyLoss2p5D, VocabParallelCrossEntropyLoss2p5D
from .loss_3d import CrossEntropyLoss3D from .loss_3d import CrossEntropyLoss3D, VocabParallelCrossEntropyLoss3D
from .loss_moe import MoeCrossEntropyLoss, MoeLoss from .loss_moe import MoeCrossEntropyLoss, MoeLoss
_parallel_cross_entropy = { _parallel_cross_entropy = {
'2d': CrossEntropyLoss2D, '2d': CrossEntropyLoss2D,
'2.5d': CrossEntropyLoss2p5D, '2.5d': CrossEntropyLoss2p5D,
'3d': CrossEntropyLoss3D '3d': CrossEntropyLoss3D,
}
_vocab_parallel_cross_entropy = {
'1d': VocabParallelCrossEntropyLoss1D,
'2d': VocabParallelCrossEntropyLoss2D,
'2.5d': VocabParallelCrossEntropyLoss2p5D,
'3d': VocabParallelCrossEntropyLoss3D,
} }
class CrossEntropyLoss(_Loss): class CrossEntropyLoss(_Loss):
def __init__(self, reduction: bool = True, *args, **kwargs): def __init__(self, reduction: bool = True, *args, **kwargs):
super().__init__() super().__init__()
tensor_parallel = get_tensor_parallel_mode() tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']: if tensor_parallel is not None and env.vocab_parallel:
self.loss = _vocab_parallel_cross_entropy[tensor_parallel](reduction=reduction, *args, **kwargs)
elif tensor_parallel is None or tensor_parallel == '1d':
reduction = 'mean' if reduction else 'none' reduction = 'mean' if reduction else 'none'
self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs) self.loss = nn.CrossEntropyLoss(reduction=reduction, *args, **kwargs)
else: else:
......
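The dispatch above means user code keeps constructing a single CrossEntropyLoss; which implementation backs it now depends on env.mode and on whether any VocabParallel* layer has set env.vocab_parallel. A minimal sketch of the expected behaviour (assumes a launched colossalai context; the colossalai.nn re-export of CrossEntropyLoss is an assumption):
# hedged sketch: requires colossalai.launch to have configured the tensor-parallel env
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.nn import CrossEntropyLoss  # re-export assumed
criterion = CrossEntropyLoss(reduction=True)
# with env.mode == '2d' and env.vocab_parallel == True (set by a VocabParallel* layer),
# criterion.loss is a VocabParallelCrossEntropyLoss2D; with mode None or '1d' and no vocab
# parallelism it falls back to torch.nn.CrossEntropyLoss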
import torch
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.registry import LOSSES
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.modules.loss import _Loss
class _VocabParallelCrossEntropy1D(torch.autograd.Function):
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, vocab_parallel_logits, targets):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
        # Get the partition's vocab indices
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
vocab_start_index = partition_vocab_size * rank
vocab_end_index = vocab_start_index + partition_vocab_size
        # Create a mask of out-of-range vocab ids (1 means the id needs to be masked).
target_mask = (targets < vocab_start_index) | (targets >= vocab_end_index)
masked_target = targets.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
        # For simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0], device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(targets)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
        # Retrieve tensors from the forward pass.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
        # All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss1D(_Loss):
"""
Vocab parallel cross entropy loss for 1D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
loss = _VocabParallelCrossEntropy1D.apply(logits, targets)
if self.reduction_mean:
loss = loss.mean()
return loss
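A minimal usage sketch for VocabParallelCrossEntropyLoss1D (assumes a launched 1D tensor-parallel group of size 4; vocabulary size and shapes are illustrative). Each rank passes only its slice of the logits along the vocabulary dimension, while the targets carry full vocabulary ids; in training the logits would come from a VocabParallelClassifier1D.
# hedged sketch: requires an initialized 1D tensor-parallel group (size 4 assumed)
import torch
criterion = VocabParallelCrossEntropyLoss1D(reduction=True)
vocab_parallel_logits = torch.randn(16, 2048, device='cuda')  # local slice of an 8192-entry vocab
targets = torch.randint(0, 8192, (16,), device='cuda')        # targets are not split
loss = criterion(vocab_parallel_logits, targets)              # scalar, averaged over the batch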
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d import torch
import torch.distributed as dist
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import reduce_by_batch_2d, split_tensor_2d
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization
from colossalai.registry import LOSSES from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
from torch.nn.functional import cross_entropy from torch.nn.functional import cross_entropy
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
...@@ -16,6 +22,7 @@ class CrossEntropyLoss2D(_Loss): ...@@ -16,6 +22,7 @@ class CrossEntropyLoss2D(_Loss):
:type reduction: bool, optional :type reduction: bool, optional
""" """
def __init__(self, reduction=True, *args, **kwargs): def __init__(self, reduction=True, *args, **kwargs):
super().__init__() super().__init__()
assert_summa_initialization() assert_summa_initialization()
...@@ -29,8 +36,110 @@ class CrossEntropyLoss2D(_Loss): ...@@ -29,8 +36,110 @@ class CrossEntropyLoss2D(_Loss):
:param logits: Output logits of model :param logits: Output logits of model
:param targets: True targets from data :param targets: True targets from data
""" """
targets = split_tensor_2d(targets)
loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs) loss = cross_entropy(logits, targets, reduction='none', *self.loss_args, **self.loss_kwargs)
if self.reduction_mean: if self.reduction_mean:
loss = loss.mean() loss = loss.mean()
loss = reduce_by_batch_2d.apply(loss, True) loss = reduce_by_batch_2d(loss, True)
return loss
class _VocabParallelCrossEntropy2D(torch.autograd.Function):
    ### Adapted from megatron.mpu.cross_entropy ###
@staticmethod
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, logits, targets):
        # logits: [b/q, v/q]
# labels: [b/q]
# loss: [b/q]
# vocab_parallel_logits: [b/q, s, v/q]
# target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
# Subtract the maximum value.
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
        arange_1d = torch.arange(start=0, end=logits.size()[0], device=get_current_device())
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
@custom_bwd
def backward(ctx, output_grad):
        # Retrieve tensors from the forward pass.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0], device=get_current_device())
grad_2d[arange_1d, masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
@LOSSES.register_module
class VocabParallelCrossEntropyLoss2D(_Loss):
"""
Vocab parallel cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
self.reduction_mean = reduction
def forward(self, logits, targets):
"""Calculate loss between logits and targets
:param logits: Output logits of model
:param targets: True targets from data
"""
targets = split_tensor_2d(targets)
loss = _VocabParallelCrossEntropy2D.apply(
logits,
targets,
)
if self.reduction_mean:
loss = loss.mean()
loss = reduce_by_batch_2d(loss, True)
return loss return loss
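And a matching sketch for the 2D case (assumes a launched 2x2 SUMMA tensor-parallel context; shapes are illustrative). Here the logits are partitioned along both the batch and the vocabulary dimension, while the targets carry the full batch and are split inside forward by split_tensor_2d.
# hedged sketch: requires an initialized 2D (SUMMA) tensor-parallel context (2x2 assumed)
import torch
criterion = VocabParallelCrossEntropyLoss2D(reduction=True)
logits = torch.randn(8, 4096, device='cuda')            # local [b/q, v/q] shard with q = 2
targets = torch.randint(0, 8192, (16,), device='cuda')  # full batch of 16; split inside forward
loss = criterion(logits, targets)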