Commit 9ee197d0 authored by アマデウス's avatar アマデウス Committed by Frank Lee
Browse files

moved env variables to global variables; (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed few collective communicator bugs
parent b82d60be
......@@ -137,8 +137,4 @@ dmypy.json
.DS_Store
#data/
# launcher setting
tests/launcher/log
tests/launcher/personal
docs/.build
......@@ -5,7 +5,7 @@ repos:
- id: yapf
args: ['--style=google', '--parallel', '--in-place']
- repo: https://github.com/pycqa/flake8
rev: ''
rev: '4.0.1'
hooks:
- id: flake8
- repo: https://github.com/pre-commit/mirrors-clang-format
......
......@@ -4,8 +4,9 @@
import torch.nn as nn
try:
import apex.amp as apex_amp
except:
pass
except ImportError:
raise ImportError('Cannot import apex.amp correctly.')
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
......
......@@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
"""
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = [tensor]
out = tensor
work = None
else:
shape = list(tensor.shape)
......@@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor,
async_op: bool = False) -> Tensor:
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
out = tensor.contiguous()
work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
return out, work
else:
return tensor
return out
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
out = tensor.contiguous()
work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
return out, work
else:
return tensor
return out
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
out = tensor.contiguous()
work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return tensor, work
return out, work
else:
return tensor
return out
......@@ -19,23 +19,12 @@ INITIALIZER_MAPPING = {
'moe': 'Initializer_Moe'
}
# 1D parallel
PARALLEL_INPUT_1D = 'parallel_input_1d'
# 3D parallelism groups
INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
# 2D paralllel
SUMMA_DIM = 'SUMMA_DIM'
# 2.5D paralllel
TESSERACT_DIM = 'TESSERACT_DIM'
TESSERACT_DEP = 'TESSERACT_DEP'
# 3D parallel
DEPTH_3D = 'DEPTH_3D'
INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
# Tensor parallel attributes
# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
......@@ -8,14 +8,15 @@ from typing import Union
import numpy as np
import torch
import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.global_variables import moe_env
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
from colossalai.global_variables import moe_env
class ParallelContext:
......@@ -307,7 +308,6 @@ class ParallelContext:
port: int
):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
......@@ -389,7 +389,8 @@ class ParallelContext:
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
env.mode = tensor_parallel_mode
self.check_sanity()
pg_init = []
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import torch.distributed as dist
from colossalai.context import Config
import torch.distributed as dist
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
from colossalai.constants import PARALLEL_INPUT_1D
from .process_group_initializer import ProcessGroupInitializer
@DIST_GROUP_INITIALIZER.register_module
class Initializer_1D(ProcessGroupInitializer):
"""A ProcessGroupInitializer for 1d tensor parallelism.
:param args: Args used to initialize ProcessGroupInitializer
:param kwargs: Kwargs used to initialize ProcessGroupInitializer
"""
'''A ProcessGroupInitializer for 1d tensor parallelism.
'''
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
......@@ -24,7 +20,7 @@ class Initializer_1D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
:rtype: Tuple
"""
......@@ -33,7 +29,7 @@ class Initializer_1D(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_1D
os.environ[PARALLEL_INPUT_1D] = ''
env.parallel_input_1d = False
for i in range(self.num_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
......
import math
import os
import torch.distributed as dist
from colossalai.constants import SUMMA_DIM
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
from colossalai.global_variables import tensor_parallel_env as env
def _check_summa_env_var(summa_dim):
# check environment variable for SUMMA
env_summa_dim = os.environ.get(SUMMA_DIM, None)
env_summa_dim = env.summa_dim
if env_summa_dim:
assert int(env_summa_dim) == summa_dim, \
'SUMMA_DIM has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
os.environ[SUMMA_DIM] = str(summa_dim)
env.summa_dim = summa_dim
class Initializer_2D_Row(ProcessGroupInitializer):
"""2d tensor parallel initialization among rows.
:param num_group: The number of all tensor groups
:param summa_dim: The dimension of SUMMA
:param args: Args used to initialize base class
:param kwargs: Kwargs used to initialize base class
:type num_group: int
:type summa_dim: int
"""
......@@ -132,7 +129,7 @@ class Initializer_2D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
:return: 2D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""
......
......@@ -2,22 +2,21 @@
# -*- encoding: utf-8 -*-
import math
import os
import torch.distributed as dist
from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
from colossalai.context import Config
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from .process_group_initializer import ProcessGroupInitializer
from ..parallel_mode import ParallelMode
from .process_group_initializer import ProcessGroupInitializer
def _check_tesseract_env_var(tesseract_dim: int,
tesseract_dep: int):
# check environment variable for TESSERACT
env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
# check global variable for TESSERACT
env_tesseract_dim = env.tesseract_dim
env_tesseract_dep = env.tesseract_dep
if env_tesseract_dim and env_tesseract_dep:
assert int(env_tesseract_dim) == tesseract_dim, \
......@@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int,
'TESSERACT_DEP has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
os.environ[TESSERACT_DIM] = str(tesseract_dim)
os.environ[TESSERACT_DEP] = str(tesseract_dep)
env.tesseract_dim = tesseract_dim
env.tesseract_dep = tesseract_dep
# i row j col k dep
......@@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
:param pipeline_parallel_size: Size of pipeline parallel
:param tensor_parallel_size: Size of tensor parallel
:param depth: The depth of 2p5d parallel
:type rank: int
:type world_size: int
:type config: Config
......@@ -281,7 +279,7 @@ class Initializer_2p5D(ProcessGroupInitializer):
def init_dist_group(self):
"""Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
:return: Whole 2p5D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""
......
......@@ -2,10 +2,9 @@
# -*- encoding: utf-8 -*-
import math
import os
import torch.distributed as dist
from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.registry import DIST_GROUP_INITIALIZER
from ..parallel_mode import ParallelMode
......@@ -13,15 +12,15 @@ from .process_group_initializer import ProcessGroupInitializer
def _check_depth_env_var(depth):
# check environment variable for SUMMA
env_depth = os.environ.get(DEPTH_3D, None)
# check global variable
env_depth = env.depth_3d
if env_depth:
assert int(env_depth) == depth, \
'DEPTH_3D has been set in the current environment and ' \
'does not match with the value passed to this initialized'
else:
os.environ[DEPTH_3D] = str(depth)
env.depth_3d = depth
class Initializer_3D_Input(ProcessGroupInitializer):
......@@ -34,6 +33,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
:type num_group: int
:type depth: int
"""
def __init__(self, num_group: int, depth: int, *args):
super().__init__(*args)
self.num_group = num_group
......@@ -50,15 +50,12 @@ class Initializer_3D_Input(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_INPUT
os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
env.input_group_3d = mode
for h in range(self.num_group):
for i in range(self.depth):
for k in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for j in range(self.depth)
]
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
group = dist.new_group(ranks)
if self.rank in ranks:
......@@ -97,15 +94,12 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_WEIGHT
os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
env.weight_group_3d = mode
for h in range(self.num_group):
for k in range(self.depth):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for i in range(self.depth)
]
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
group = dist.new_group(ranks)
if self.rank in ranks:
......@@ -118,7 +112,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
class Initializer_3D_Output(ProcessGroupInitializer):
"""3D tensor parallel initialization among weight.
"""3D tensor parallel initialization among output.
:param num_group: The number of all tensor groups
:param depth: Depth of 3D parallelism
......@@ -144,15 +138,12 @@ class Initializer_3D_Output(ProcessGroupInitializer):
process_group = None
group_world_size = None
mode = ParallelMode.PARALLEL_3D_OUTPUT
os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
env.output_group_3d = mode
for h in range(self.num_group):
for i in range(self.depth):
for j in range(self.depth):
ranks = [
h * self.depth**3 + i + self.depth *
(j + self.depth * k) for k in range(self.depth)
]
ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
group = dist.new_group(ranks)
if self.rank in ranks:
......@@ -170,6 +161,7 @@ class Initializer_3D(ProcessGroupInitializer):
:param args: Args used to initialize ProcessGroupInitializer
"""
def __init__(self, *args):
super().__init__(*args)
self.num_group = self.world_size // self.tensor_parallel_size
......@@ -178,16 +170,13 @@ class Initializer_3D(ProcessGroupInitializer):
f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
_check_depth_env_var(self.depth)
self.input_initializer = Initializer_3D_Input(self.num_group,
self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(
self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(
self.num_group, self.depth, *args)
self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
def init_dist_group(self):
"""Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
:return: 3D tensor parallelism's information
:rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
"""
......
......@@ -9,4 +9,4 @@ from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
__all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
'MoeGradientHandler', 'SequenceParallelGradientHandler']
'MoeGradientHandler', 'SequenceParallelGradientHandler']
\ No newline at end of file
......@@ -9,7 +9,6 @@ from typing import Iterable, Callable
from .._base_engine import Engine
from colossalai.logging import get_dist_logger
from colossalai.utils import get_current_device
from colossalai.nn.layer import split_batch
class BaseSchedule(ABC):
......@@ -69,7 +68,6 @@ class BaseSchedule(ABC):
self.batch_size = data.size(0)
else:
self.batch_size = next(iter(data.values())).size(0)
data, label = split_batch(data), split_batch(label)
if to_gpu:
return self._move_to_device(data), self._move_to_device(label)
return data, label
......
from typing import Optional
class TensorParallelEnv(object):
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = object.__new__(cls, *args, **kwargs)
return cls._instance
def __init__(self, *args, **kwargs):
self.load(*args, **kwargs)
def load(self,
mode: Optional[str] = None,
vocab_parallel: bool = False,
parallel_input_1d: bool = False,
summa_dim: int = None,
tesseract_dim: int = None,
tesseract_dep: int = None,
depth_3d: int = None,
input_group_3d=None,
weight_group_3d=None,
output_group_3d=None):
self.mode = mode
self.vocab_parallel = vocab_parallel
self.parallel_input_1d = parallel_input_1d
self.summa_dim = summa_dim
self.tesseract_dim = tesseract_dim
self.tesseract_dep = tesseract_dep
self.depth_3d = depth_3d
self.input_group_3d = input_group_3d
self.weight_group_3d = weight_group_3d
self.output_group_3d = output_group_3d
def save(self):
return dict(mode=self.mode,
vocab_parallel=self.vocab_parallel,
parallel_input_1d=self.parallel_input_1d,
summa_dim=self.summa_dim,
tesseract_dim=self.tesseract_dim,
tesseract_dep=self.tesseract_dep,
depth_3d=self.depth_3d,
input_group_3d=self.input_group_3d,
weight_group_3d=self.weight_group_3d,
output_group_3d=self.output_group_3d)
class MoeEnv:
......@@ -33,4 +81,6 @@ class MoeEnv:
return self.aux_loss
tensor_parallel_env = TensorParallelEnv()
moe_env = MoeEnv()
......@@ -37,17 +37,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= colossal_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
= colossal_layer_norm_cuda.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)
return grad_input, grad_weight, grad_bias, None, None
class MixedFusedLayerNorm(torch.nn.Module):
def __init__(self, normalized_shape, eps=1e-5):
def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
super(MixedFusedLayerNorm, self).__init__()
global colossal_layer_norm_cuda
......@@ -61,8 +61,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
normalized_shape = (normalized_shape,)
self.normalized_shape = torch.Size(normalized_shape)
self.eps = eps
self.weight = Parameter(torch.Tensor(*normalized_shape))
self.bias = Parameter(torch.Tensor(*normalized_shape))
self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
self.reset_parameters()
def reset_parameters(self):
......
from ._utils import split_batch
from ._utils import partition_batch
from .dropout import Dropout
from .embedding import Embedding, PatchEmbedding
from .linear import Classifier, Linear
from .normalization import LayerNorm
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
......@@ -2,13 +2,13 @@ from torch import Tensor
from ..parallel_2d._operation import split_tensor_2d
from ..parallel_2p5d._operation import split_tensor_2p5d
from ..parallel_3d._operation import split_tensor_3d
from ..parallel_3d._operation import split_batch_3d
from ..utils import get_tensor_parallel_mode
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
def split_batch(input_) -> Tensor:
def partition_batch(input_) -> Tensor:
tensor_parallel_mode = get_tensor_parallel_mode()
if tensor_parallel_mode in _parallel_split_batch:
if isinstance(input_, dict):
......
from contextlib import nullcontext
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import conditional_context
from ..parallel_1d import *
from ..utils import get_tensor_parallel_mode
......@@ -26,6 +23,8 @@ class Dropout(nn.Module):
self.drop = nn.Dropout(p, inplace)
def forward(self, *args):
cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
with cm:
if self.tensor_parallel in [None, '1d']:
return self.drop(*args)
else:
with seed(ParallelMode.TENSOR):
return self.drop(*args)
import math
from typing import Callable, Optional
from typing import Callable
from colossalai.utils import get_current_device
from torch import dtype, nn
......@@ -12,10 +12,21 @@ from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
_parallel_embedding = {
'2d': Embedding2D,
'2.5d': Embedding2p5D,
'3d': Embedding3D,
}
_vocab_parallel_embedding = {
'1d': VocabParallelEmbedding1D,
'2d': VocabParallelEmbedding2D,
'2.5d': VocabParallelEmbedding2p5D,
'3d': VocabParallelEmbedding3D
}
_parallel_patchembedding = {
'None': VanillaPatchEmbedding,
None: VanillaPatchEmbedding,
'1d': VanillaPatchEmbedding,
'2d': PatchEmbedding2D,
'2.5d': PatchEmbedding2p5D,
......@@ -40,26 +51,23 @@ class Embedding(nn.Module):
:param args: Args used in F.embedding
:param kwargs: Kwargs used in F.embedding
"""
def __init__(self,
num_embeddings: int,
embedding_dim: int,
padding_idx: int = None,
dtype: dtype = None,
weight_initializer: Callable = init.normal_(),
vocab_parallel_limit: int = 2048,
*args,
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.embed = nn.Embedding(num_embeddings,
embedding_dim,
padding_idx=padding_idx,
device=get_current_device(),
dtype=dtype,
*args,
**kwargs)
if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
**kwargs).to(dtype).to(get_current_device())
weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
else:
elif num_embeddings <= vocab_parallel_limit:
self.embed = _parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
......@@ -69,6 +77,16 @@ class Embedding(nn.Module):
*args,
**kwargs,
)
else:
self.embed = _vocab_parallel_embedding[tensor_parallel](
num_embeddings,
embedding_dim,
padding_idx=padding_idx,
dtype=dtype,
weight_initializer=weight_initializer,
*args,
**kwargs,
)
@property
def weight(self):
......@@ -101,16 +119,19 @@ class PatchEmbedding(nn.Module):
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()) -> None:
def __init__(
self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
dtype: dtype = None,
flatten: bool = True,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()
) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
self.embed = _parallel_patchembedding[tensor_parallel](
......
import math
from typing import Callable, Optional
from typing import Callable
from colossalai.nn.layer.parallel_1d.layers import Classifier1D
from colossalai.utils import get_current_device
from torch import dtype, nn
......@@ -16,13 +15,20 @@ from ..vanilla import *
_parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
_parallel_classifier = {
'None': VanillaClassifier,
None: VanillaClassifier,
'1d': Classifier1D,
'2d': Classifier2D,
'2.5d': Classifier2p5D,
'3d': Classifier3D
}
_vocab_parallel_classifier = {
'1d': VocabParallelClassifier1D,
'2d': VocabParallelClassifier2D,
'2.5d': VocabParallelClassifier2p5D,
'3d': VocabParallelClassifier3D
}
class Linear(nn.Module):
"""
......@@ -40,8 +46,9 @@ class Linear(nn.Module):
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param kwargs: Kwargs used for initialization
:param kwargs: Kwargs used for particular parallelisms
"""
def __init__(self,
in_features: int,
out_features: int,
......@@ -52,10 +59,10 @@ class Linear(nn.Module):
**kwargs) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel == 'None':
self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
if tensor_parallel is None:
self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
if bias:
if self.layer.bias is not None:
bias_initializer(self.layer.bias, fan_in=in_features)
else:
self.layer = _parallel_linear[tensor_parallel](
......@@ -97,26 +104,38 @@ class Classifier(nn.Module):
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(
self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
) -> None:
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
vocab_parallel_limit: int = 2048) -> None:
super().__init__()
self.layer = _parallel_classifier[get_tensor_parallel_mode()](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
tensor_parallel = get_tensor_parallel_mode()
if num_classes <= vocab_parallel_limit or tensor_parallel is None:
self.layer = _parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
else:
self.layer = _vocab_parallel_classifier[tensor_parallel](
in_features,
num_classes,
weight=weight,
bias=bias,
dtype=dtype,
weight_initializer=weight_initializer,
bias_initializer=bias_initializer,
)
@property
def weight(self):
......
from typing import Optional
from colossalai.utils import get_current_device
from torch import nn
from colossalai import kernel
from ... import init as init
from ..parallel_1d import *
......@@ -11,7 +10,12 @@ from ..parallel_3d import *
from ..utils import get_tensor_parallel_mode
from ..vanilla import *
_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
_parallel_layernorm = {
'1d': kernel.LayerNorm,
'2d': LayerNorm2D,
'2.5d': LayerNorm2p5D,
'3d': LayerNorm3D
}
class LayerNorm(nn.Module):
......@@ -28,11 +32,12 @@ class LayerNorm(nn.Module):
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
"""
def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
super().__init__()
tensor_parallel = get_tensor_parallel_mode()
if tensor_parallel in ['None', '1d']:
self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
if tensor_parallel is None:
self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
else:
self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment