Commit 9ee197d0 authored by アマデウス, committed by Frank Lee

moved env variables to global variables; (#215)

added branch context;
added vocab parallel layers;
moved split_batch from load_batch to tensor parallel embedding layers;
updated gpt model;
updated unit test cases;
fixed a few collective communication bugs
parent b82d60be
@@ -137,8 +137,4 @@ dmypy.json
 .DS_Store
 #data/
 
-# launcher setting
-tests/launcher/log
-tests/launcher/personal
-
 docs/.build
@@ -5,7 +5,7 @@ repos:
       - id: yapf
         args: ['--style=google', '--parallel', '--in-place']
   - repo: https://github.com/pycqa/flake8
-    rev: ''
+    rev: '4.0.1'
     hooks:
       - id: flake8
   - repo: https://github.com/pre-commit/mirrors-clang-format
......
@@ -4,8 +4,9 @@
 import torch.nn as nn
 
 try:
    import apex.amp as apex_amp
-except:
-    pass
+except ImportError:
+    raise ImportError('Cannot import apex.amp correctly.')
+
 from torch import Tensor
 
 from colossalai.nn.optimizer import ColossalaiOptimizer
......
@@ -30,7 +30,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
     """
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
-        out = [tensor]
+        out = tensor
         work = None
     else:
         shape = list(tensor.shape)
@@ -96,34 +96,40 @@ def all_reduce(tensor: Tensor,
                async_op: bool = False) -> Tensor:
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.all_reduce(tensor.contiguous(), op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out


 def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.broadcast(tensor.contiguous(), src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor


 def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
     depth = gpc.get_world_size(parallel_mode)
     if depth == 1:
+        out = tensor
         work = None
     else:
-        work = dist.reduce(tensor.contiguous(), dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
+        out = tensor.contiguous()
+        work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
     if async_op:
-        return tensor, work
+        return out, work
     else:
-        return tensor
+        return out
@@ -19,23 +19,12 @@ INITIALIZER_MAPPING = {
     'moe': 'Initializer_Moe'
 }
 
-# 1D parallel
-PARALLEL_INPUT_1D = 'parallel_input_1d'
-
-# 2D paralllel
-SUMMA_DIM = 'SUMMA_DIM'
-
-# 2.5D paralllel
-TESSERACT_DIM = 'TESSERACT_DIM'
-TESSERACT_DEP = 'TESSERACT_DEP'
-
-# 3D parallel
-DEPTH_3D = 'DEPTH_3D'
-INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
-WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
-OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
+# 3D parallelism groups
+INPUT_GROUP_3D = 'input_group_3d'
+WEIGHT_GROUP_3D = 'weight_group_3d'
+OUTPUT_GROUP_3D = 'output_group_3d'
 
-# Tensor parallel attributes
+# Attributes of tensor parallel parameters
 IS_TENSOR_PARALLEL = 'is_tensor_parallel'
 NUM_PARTITIONS = 'num_partitions'
 TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
@@ -8,14 +8,15 @@ from typing import Union
 import numpy as np
 import torch
 import torch.distributed as dist
-from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING, TENSOR_PARALLEL_MODE
+from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
 from colossalai.context.config import Config
+from colossalai.global_variables import moe_env
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.logging import get_dist_logger
 from colossalai.registry import DIST_GROUP_INITIALIZER
 
 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
-from colossalai.global_variables import moe_env
 
 
 class ParallelContext:
@@ -307,7 +308,6 @@ class ParallelContext:
                  port: int
     ):
         """Initializes the global distributed environment
-
         :param rank: rank for the default process group
         :type rank: int
         :param world_size: world size of the default process group
@@ -389,7 +389,8 @@ class ParallelContext:
         if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
             tensor_parallel_mode = parallel_config['tensor']['mode']
             assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
-            os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
+            env.mode = tensor_parallel_mode
+
         self.check_sanity()
 
         pg_init = []
......
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-import os
 import torch.distributed as dist
 
-from colossalai.context import Config
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
-from colossalai.constants import PARALLEL_INPUT_1D
+from .process_group_initializer import ProcessGroupInitializer
 
 
 @DIST_GROUP_INITIALIZER.register_module
 class Initializer_1D(ProcessGroupInitializer):
-    """A ProcessGroupInitializer for 1d tensor parallelism.
-
-    :param args: Args used to initialize ProcessGroupInitializer
-    :param kwargs: Kwargs used to initialize ProcessGroupInitializer
-    """
+    '''A ProcessGroupInitializer for 1d tensor parallelism.
+    '''
 
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -24,7 +20,7 @@ class Initializer_1D(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize 1D tensor parallel groups, and assign local_ranks and groups to each gpu.
         :return: (local_rank, group_world_size, process_group, ranks_in_group, mode)
         :rtype: Tuple
         """
@@ -33,7 +29,7 @@ class Initializer_1D(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_1D
-        os.environ[PARALLEL_INPUT_1D] = ''
+        env.parallel_input_1d = False
 
         for i in range(self.num_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
......
 import math
-import os
 
 import torch.distributed as dist
 
-from colossalai.constants import SUMMA_DIM
 from colossalai.registry import DIST_GROUP_INITIALIZER
 
 from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from colossalai.global_variables import tensor_parallel_env as env
 
 
 def _check_summa_env_var(summa_dim):
     # check environment variable for SUMMA
-    env_summa_dim = os.environ.get(SUMMA_DIM, None)
+    env_summa_dim = env.summa_dim
 
     if env_summa_dim:
         assert int(env_summa_dim) == summa_dim, \
             'SUMMA_DIM has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[SUMMA_DIM] = str(summa_dim)
+        env.summa_dim = summa_dim
 
 
 class Initializer_2D_Row(ProcessGroupInitializer):
     """2d tensor parallel initialization among rows.
     :param num_group: The number of all tensor groups
     :param summa_dim: The dimension of SUMMA
     :param args: Args used to initialize base class
     :param kwargs: Kwargs used to initialize base class
 
     :type num_group: int
     :type summa_dim: int
     """
@@ -132,7 +129,7 @@ class Initializer_2D(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize 2D tensor row and col parallel groups, and assign local_ranks and groups to each gpu.
         :return: 2D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
......
@@ -2,22 +2,21 @@
 # -*- encoding: utf-8 -*-
 
 import math
-import os
 
 import torch.distributed as dist
 
-from colossalai.constants import TESSERACT_DIM, TESSERACT_DEP
 from colossalai.context import Config
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
-from .process_group_initializer import ProcessGroupInitializer
 from ..parallel_mode import ParallelMode
+from .process_group_initializer import ProcessGroupInitializer
 
 
 def _check_tesseract_env_var(tesseract_dim: int,
                              tesseract_dep: int):
-    # check environment variable for TESSERACT
-    env_tesseract_dim = os.environ.get(TESSERACT_DIM, None)
-    env_tesseract_dep = os.environ.get(TESSERACT_DEP, None)
+    # check global variable for TESSERACT
+    env_tesseract_dim = env.tesseract_dim
+    env_tesseract_dep = env.tesseract_dep
 
     if env_tesseract_dim and env_tesseract_dep:
         assert int(env_tesseract_dim) == tesseract_dim, \
@@ -27,8 +26,8 @@ def _check_tesseract_env_var(tesseract_dim: int,
             'TESSERACT_DEP has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[TESSERACT_DIM] = str(tesseract_dim)
-        os.environ[TESSERACT_DEP] = str(tesseract_dep)
+        env.tesseract_dim = tesseract_dim
+        env.tesseract_dep = tesseract_dep
 
 
 # i row  j col  k dep
@@ -245,7 +244,6 @@ class Initializer_2p5D(ProcessGroupInitializer):
     :param pipeline_parallel_size: Size of pipeline parallel
     :param tensor_parallel_size: Size of tensor parallel
     :param depth: The depth of 2p5d parallel
-
     :type rank: int
     :type world_size: int
     :type config: Config
@@ -281,7 +279,7 @@ class Initializer_2p5D(ProcessGroupInitializer):
     def init_dist_group(self):
         """Initialize 2p5D tensor row, col, depth, and colXdepth parallel groups, and assign local_ranks and groups to each gpu.
         :return: Whole 2p5D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
......
@@ -2,10 +2,9 @@
 # -*- encoding: utf-8 -*-
 
 import math
-import os
 
 import torch.distributed as dist
 
-from colossalai.constants import DEPTH_3D, INPUT_GROUP_3D, WEIGHT_GROUP_3D, OUTPUT_GROUP_3D
+from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.registry import DIST_GROUP_INITIALIZER
 
 from ..parallel_mode import ParallelMode
@@ -13,15 +12,15 @@ from .process_group_initializer import ProcessGroupInitializer
 
 
 def _check_depth_env_var(depth):
-    # check environment variable for SUMMA
-    env_depth = os.environ.get(DEPTH_3D, None)
+    # check global variable
+    env_depth = env.depth_3d
 
     if env_depth:
         assert int(env_depth) == depth, \
             'DEPTH_3D has been set in the current environment and ' \
             'does not match with the value passed to this initialized'
     else:
-        os.environ[DEPTH_3D] = str(depth)
+        env.depth_3d = depth
 
 
 class Initializer_3D_Input(ProcessGroupInitializer):
@@ -34,6 +33,7 @@ class Initializer_3D_Input(ProcessGroupInitializer):
     :type num_group: int
     :type depth: int
     """
+
     def __init__(self, num_group: int, depth: int, *args):
         super().__init__(*args)
         self.num_group = num_group
@@ -50,15 +50,12 @@ class Initializer_3D_Input(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_INPUT
-        os.environ[INPUT_GROUP_3D] = INPUT_GROUP_3D
+        env.input_group_3d = mode
 
         for h in range(self.num_group):
             for i in range(self.depth):
                 for k in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for j in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for j in range(self.depth)]
                     group = dist.new_group(ranks)
 
                     if self.rank in ranks:
@@ -97,15 +94,12 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_WEIGHT
-        os.environ[WEIGHT_GROUP_3D] = WEIGHT_GROUP_3D
+        env.weight_group_3d = mode
 
         for h in range(self.num_group):
             for k in range(self.depth):
                 for j in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for i in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for i in range(self.depth)]
                     group = dist.new_group(ranks)
 
                     if self.rank in ranks:
@@ -118,7 +112,7 @@ class Initializer_3D_Weight(ProcessGroupInitializer):
 
 
 class Initializer_3D_Output(ProcessGroupInitializer):
-    """3D tensor parallel initialization among weight.
+    """3D tensor parallel initialization among output.
 
     :param num_group: The number of all tensor groups
     :param depth: Depth of 3D parallelism
@@ -144,15 +138,12 @@ class Initializer_3D_Output(ProcessGroupInitializer):
         process_group = None
         group_world_size = None
         mode = ParallelMode.PARALLEL_3D_OUTPUT
-        os.environ[OUTPUT_GROUP_3D] = OUTPUT_GROUP_3D
+        env.output_group_3d = mode
 
         for h in range(self.num_group):
             for i in range(self.depth):
                 for j in range(self.depth):
-                    ranks = [
-                        h * self.depth**3 + i + self.depth *
-                        (j + self.depth * k) for k in range(self.depth)
-                    ]
+                    ranks = [h * self.depth**3 + i + self.depth * (j + self.depth * k) for k in range(self.depth)]
                     group = dist.new_group(ranks)
 
                     if self.rank in ranks:
@@ -170,6 +161,7 @@ class Initializer_3D(ProcessGroupInitializer):
     :param args: Args used to initialize ProcessGroupInitializer
     """
+
     def __init__(self, *args):
         super().__init__(*args)
         self.num_group = self.world_size // self.tensor_parallel_size
@@ -178,16 +170,13 @@ class Initializer_3D(ProcessGroupInitializer):
             f'3D depth ({self.depth}) if not cube root of tensor parallel size ({self.tensor_parallel_size})'
         _check_depth_env_var(self.depth)
 
-        self.input_initializer = Initializer_3D_Input(self.num_group,
-                                                      self.depth, *args)
-        self.weight_initializer = Initializer_3D_Weight(
-            self.num_group, self.depth, *args)
-        self.output_initializer = Initializer_3D_Output(
-            self.num_group, self.depth, *args)
+        self.input_initializer = Initializer_3D_Input(self.num_group, self.depth, *args)
+        self.weight_initializer = Initializer_3D_Weight(self.num_group, self.depth, *args)
+        self.output_initializer = Initializer_3D_Output(self.num_group, self.depth, *args)
 
     def init_dist_group(self):
         """Initialize 3D tensor parallel groups, and assign local_ranks and groups to each gpu.
         :return: 3D tensor parallelism's information
         :rtype: list of Tuples (local_rank, group_world_size, process_group, ranks_in_group, mode)
         """
......
@@ -9,4 +9,4 @@ from ._sequence_parallel_gradient_handler import SequenceParallelGradientHandler
 
 __all__ = ['BaseGradientHandler', 'DataParallelGradientHandler',
            'ZeROGradientHandler', 'PipelineSharedModuleGradientHandler',
-           'MoeGradientHandler', 'SequenceParallelGradientHandler']
\ No newline at end of file
+           'MoeGradientHandler', 'SequenceParallelGradientHandler']
@@ -9,7 +9,6 @@ from typing import Iterable, Callable
 from .._base_engine import Engine
 from colossalai.logging import get_dist_logger
 from colossalai.utils import get_current_device
-from colossalai.nn.layer import split_batch
 
 
 class BaseSchedule(ABC):
@@ -69,7 +68,6 @@ class BaseSchedule(ABC):
             self.batch_size = data.size(0)
         else:
             self.batch_size = next(iter(data.values())).size(0)
-        data, label = split_batch(data), split_batch(label)
         if to_gpu:
             return self._move_to_device(data), self._move_to_device(label)
         return data, label
......
+from typing import Optional
+
+
+class TensorParallelEnv(object):
+    _instance = None
+
+    def __new__(cls, *args, **kwargs):
+        if cls._instance is None:
+            cls._instance = object.__new__(cls, *args, **kwargs)
+        return cls._instance
+
+    def __init__(self, *args, **kwargs):
+        self.load(*args, **kwargs)
+
+    def load(self,
+             mode: Optional[str] = None,
+             vocab_parallel: bool = False,
+             parallel_input_1d: bool = False,
+             summa_dim: int = None,
+             tesseract_dim: int = None,
+             tesseract_dep: int = None,
+             depth_3d: int = None,
+             input_group_3d=None,
+             weight_group_3d=None,
+             output_group_3d=None):
+        self.mode = mode
+        self.vocab_parallel = vocab_parallel
+        self.parallel_input_1d = parallel_input_1d
+        self.summa_dim = summa_dim
+        self.tesseract_dim = tesseract_dim
+        self.tesseract_dep = tesseract_dep
+        self.depth_3d = depth_3d
+        self.input_group_3d = input_group_3d
+        self.weight_group_3d = weight_group_3d
+        self.output_group_3d = output_group_3d
+
+    def save(self):
+        return dict(mode=self.mode,
+                    vocab_parallel=self.vocab_parallel,
+                    parallel_input_1d=self.parallel_input_1d,
+                    summa_dim=self.summa_dim,
+                    tesseract_dim=self.tesseract_dim,
+                    tesseract_dep=self.tesseract_dep,
+                    depth_3d=self.depth_3d,
+                    input_group_3d=self.input_group_3d,
+                    weight_group_3d=self.weight_group_3d,
+                    output_group_3d=self.output_group_3d)
+
+
 class MoeEnv:
@@ -33,4 +81,6 @@ class MoeEnv:
         return self.aux_loss
 
 
+tensor_parallel_env = TensorParallelEnv()
+
 moe_env = MoeEnv()
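
A hedged usage sketch of the new module (it assumes a ColossalAI build containing this commit is importable; the values are illustrative). The process-group initializers above now read and write typed attributes on this process-local singleton instead of stringly-typed `os.environ` entries, so call sites no longer need `str()`/`int()` conversions:

```python
from colossalai.global_variables import tensor_parallel_env as env

# What ParallelContext and _check_summa_env_var now do, expressed directly:
env.mode = '2d'        # formerly os.environ[TENSOR_PARALLEL_MODE] = str(tensor_parallel_mode)
env.summa_dim = 2      # formerly os.environ[SUMMA_DIM] = str(summa_dim)

if env.mode == '2d':
    print(f'2D tensor parallelism, SUMMA dim = {env.summa_dim}')
```

One caveat the singleton implies: constructing `TensorParallelEnv()` again re-runs `load()` with default arguments, so settings should be mutated on the shared `tensor_parallel_env` instance rather than by re-instantiating the class.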
@@ -37,17 +37,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
         input_, weight_, bias_, mean, invvar = ctx.saved_tensors
         grad_input = grad_weight = grad_bias = None
         grad_input, grad_weight, grad_bias \
             = colossal_layer_norm_cuda.backward_affine(
                 grad_output.contiguous(), mean, invvar,
                 input_, ctx.normalized_shape,
                 weight_, bias_, ctx.eps)
 
         return grad_input, grad_weight, grad_bias, None, None
 
 
 class MixedFusedLayerNorm(torch.nn.Module):
 
-    def __init__(self, normalized_shape, eps=1e-5):
+    def __init__(self, normalized_shape, eps=1e-5, device=None, dtype=None):
         super(MixedFusedLayerNorm, self).__init__()
 
         global colossal_layer_norm_cuda
@@ -61,8 +61,8 @@ class MixedFusedLayerNorm(torch.nn.Module):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
         self.eps = eps
-        self.weight = Parameter(torch.Tensor(*normalized_shape))
-        self.bias = Parameter(torch.Tensor(*normalized_shape))
+        self.weight = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
+        self.bias = Parameter(torch.empty(*normalized_shape, device=device, dtype=dtype))
         self.reset_parameters()
 
     def reset_parameters(self):
......
-from ._utils import split_batch
+from ._utils import partition_batch
 from .dropout import Dropout
 from .embedding import Embedding, PatchEmbedding
 from .linear import Classifier, Linear
 from .normalization import LayerNorm
 
-__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'split_batch']
+__all__ = ['Linear', 'Classifier', 'Embedding', 'PatchEmbedding', 'LayerNorm', 'Dropout', 'partition_batch']
@@ -2,13 +2,13 @@ from torch import Tensor
 
 from ..parallel_2d._operation import split_tensor_2d
 from ..parallel_2p5d._operation import split_tensor_2p5d
-from ..parallel_3d._operation import split_tensor_3d
+from ..parallel_3d._operation import split_batch_3d
 from ..utils import get_tensor_parallel_mode
 
-_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_tensor_3d}
+_parallel_split_batch = {'2d': split_tensor_2d, '2.5d': split_tensor_2p5d, '3d': split_batch_3d}
 
 
-def split_batch(input_) -> Tensor:
+def partition_batch(input_) -> Tensor:
     tensor_parallel_mode = get_tensor_parallel_mode()
     if tensor_parallel_mode in _parallel_split_batch:
         if isinstance(input_, dict):
......
-from contextlib import nullcontext
-
 import torch.nn as nn
 from colossalai.context import ParallelMode, seed
-from colossalai.utils import conditional_context
 
 from ..parallel_1d import *
 from ..utils import get_tensor_parallel_mode
@@ -26,6 +23,8 @@ class Dropout(nn.Module):
         self.drop = nn.Dropout(p, inplace)
 
     def forward(self, *args):
-        cm = nullcontext() if self.tensor_parallel in ['None', '1d'] else seed(ParallelMode.TENSOR)
-        with cm:
+        if self.tensor_parallel in [None, '1d']:
             return self.drop(*args)
+        else:
+            with seed(ParallelMode.TENSOR):
+                return self.drop(*args)
 import math
-from typing import Callable, Optional
+from typing import Callable
 
 from colossalai.utils import get_current_device
 from torch import dtype, nn
@@ -12,10 +12,21 @@ from ..parallel_3d import *
 from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 
-_parallel_embedding = {'1d': Embedding1D, '2d': Embedding2D, '2.5d': Embedding2p5D, '3d': Embedding3D}
+_parallel_embedding = {
+    '2d': Embedding2D,
+    '2.5d': Embedding2p5D,
+    '3d': Embedding3D,
+}
+
+_vocab_parallel_embedding = {
+    '1d': VocabParallelEmbedding1D,
+    '2d': VocabParallelEmbedding2D,
+    '2.5d': VocabParallelEmbedding2p5D,
+    '3d': VocabParallelEmbedding3D
+}
 
 _parallel_patchembedding = {
-    'None': VanillaPatchEmbedding,
+    None: VanillaPatchEmbedding,
     '1d': VanillaPatchEmbedding,
     '2d': PatchEmbedding2D,
     '2.5d': PatchEmbedding2p5D,
@@ -40,26 +51,23 @@ class Embedding(nn.Module):
     :param args: Args used in F.embedding
     :param kwargs: Kwargs used in F.embedding
     """
+
     def __init__(self,
                  num_embeddings: int,
                  embedding_dim: int,
                  padding_idx: int = None,
                  dtype: dtype = None,
                  weight_initializer: Callable = init.normal_(),
+                 vocab_parallel_limit: int = 2048,
                  *args,
                  **kwargs) -> None:
         super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel == 'None':
-            self.embed = nn.Embedding(num_embeddings,
-                                      embedding_dim,
-                                      padding_idx=padding_idx,
-                                      device=get_current_device(),
-                                      dtype=dtype,
-                                      *args,
-                                      **kwargs)
+        if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
+            self.embed = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx, *args,
+                                      **kwargs).to(dtype).to(get_current_device())
             weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)
-        else:
+        elif num_embeddings <= vocab_parallel_limit:
             self.embed = _parallel_embedding[tensor_parallel](
                 num_embeddings,
                 embedding_dim,
@@ -69,6 +77,16 @@ class Embedding(nn.Module):
                 *args,
                 **kwargs,
             )
+        else:
+            self.embed = _vocab_parallel_embedding[tensor_parallel](
+                num_embeddings,
+                embedding_dim,
+                padding_idx=padding_idx,
+                dtype=dtype,
+                weight_initializer=weight_initializer,
+                *args,
+                **kwargs,
+            )
 
     @property
     def weight(self):
@@ -101,16 +119,19 @@ class PatchEmbedding(nn.Module):
     :param position_embed_initializer: The intializer of position embedding, defaults to zero
     :type position_embed_initializer: typing.Callable, optional
     """
-    def __init__(self,
-                 img_size: int,
-                 patch_size: int,
-                 in_chans: int,
-                 embed_size: int,
-                 dtype: dtype = None,
-                 flatten: bool = True,
-                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
-                 position_embed_initializer: Callable = init.zeros_()) -> None:
+
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        embed_size: int,
+        dtype: dtype = None,
+        flatten: bool = True,
+        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+        position_embed_initializer: Callable = init.zeros_()
+    ) -> None:
         super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
         self.embed = _parallel_patchembedding[tensor_parallel](
......
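
The `Embedding` wrapper above now chooses between three implementations based on the tensor parallel mode and the new `vocab_parallel_limit` argument. A standalone sketch of just that decision (the helper name `choose_embedding_impl` is illustrative and not part of the commit):

```python
from typing import Optional


def choose_embedding_impl(tensor_parallel: Optional[str],
                          num_embeddings: int,
                          vocab_parallel_limit: int = 2048) -> str:
    """Mirror the branch order of Embedding.__init__ in this diff."""
    if tensor_parallel is None or (tensor_parallel == '1d' and num_embeddings <= vocab_parallel_limit):
        return 'torch.nn.Embedding'                      # plain embedding on every rank
    elif num_embeddings <= vocab_parallel_limit:
        return f'_parallel_embedding[{tensor_parallel!r}]'        # mode-specific parallel table
    else:
        return f'_vocab_parallel_embedding[{tensor_parallel!r}]'  # VocabParallelEmbedding* table


print(choose_embedding_impl('1d', 50304))   # large vocab -> vocab-parallel embedding
print(choose_embedding_impl('2d', 1000))    # small vocab -> regular 2D parallel embedding
print(choose_embedding_impl(None, 50304))   # no tensor parallelism -> torch.nn.Embedding
```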
 import math
-from typing import Callable, Optional
+from typing import Callable
 
-from colossalai.nn.layer.parallel_1d.layers import Classifier1D
 from colossalai.utils import get_current_device
 from torch import dtype, nn
@@ -16,13 +15,20 @@ from ..vanilla import *
 
 _parallel_linear = {'1d': Linear1D, '2d': Linear2D, '2.5d': Linear2p5D, '3d': Linear3D}
 
 _parallel_classifier = {
-    'None': VanillaClassifier,
+    None: VanillaClassifier,
     '1d': Classifier1D,
     '2d': Classifier2D,
     '2.5d': Classifier2p5D,
     '3d': Classifier3D
 }
 
+_vocab_parallel_classifier = {
+    '1d': VocabParallelClassifier1D,
+    '2d': VocabParallelClassifier2D,
+    '2.5d': VocabParallelClassifier2p5D,
+    '3d': VocabParallelClassifier3D
+}
+
 
 class Linear(nn.Module):
     """
@@ -40,8 +46,9 @@ class Linear(nn.Module):
     :type weight_initializer: typing.Callable, optional
     :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
     :type bias_initializer: typing.Callable, optional
-    :param kwargs: Kwargs used for initialization
+    :param kwargs: Kwargs used for particular parallelisms
     """
+
     def __init__(self,
                  in_features: int,
                  out_features: int,
@@ -52,10 +59,10 @@ class Linear(nn.Module):
                  **kwargs) -> None:
         super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel == 'None':
-            self.layer = nn.Linear(in_features, out_features, bias=bias, device=get_current_device(), dtype=dtype)
+        if tensor_parallel is None:
+            self.layer = nn.Linear(in_features, out_features, bias=bias).to(dtype).to(get_current_device())
             weight_initializer(self.layer.weight, fan_in=in_features, fan_out=out_features)
-            if bias:
+            if self.layer.bias is not None:
                 bias_initializer(self.layer.bias, fan_in=in_features)
         else:
             self.layer = _parallel_linear[tensor_parallel](
@@ -97,26 +104,38 @@ class Classifier(nn.Module):
     :param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
     :type bias_initializer: typing.Callable, optional
     """
-    def __init__(
-        self,
-        in_features: int,
-        num_classes: int,
-        weight: nn.Parameter = None,
-        bias: bool = True,
-        dtype: dtype = None,
-        weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
-        bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)
-    ) -> None:
+
+    def __init__(self,
+                 in_features: int,
+                 num_classes: int,
+                 weight: nn.Parameter = None,
+                 bias: bool = True,
+                 dtype: dtype = None,
+                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
+                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
+                 vocab_parallel_limit: int = 2048) -> None:
         super().__init__()
-        self.layer = _parallel_classifier[get_tensor_parallel_mode()](
-            in_features,
-            num_classes,
-            weight=weight,
-            bias=bias,
-            dtype=dtype,
-            weight_initializer=weight_initializer,
-            bias_initializer=bias_initializer,
-        )
+        tensor_parallel = get_tensor_parallel_mode()
+        if num_classes <= vocab_parallel_limit or tensor_parallel is None:
+            self.layer = _parallel_classifier[tensor_parallel](
+                in_features,
+                num_classes,
+                weight=weight,
+                bias=bias,
+                dtype=dtype,
+                weight_initializer=weight_initializer,
+                bias_initializer=bias_initializer,
+            )
+        else:
+            self.layer = _vocab_parallel_classifier[tensor_parallel](
+                in_features,
+                num_classes,
+                weight=weight,
+                bias=bias,
+                dtype=dtype,
+                weight_initializer=weight_initializer,
+                bias_initializer=bias_initializer,
+            )
 
     @property
     def weight(self):
......
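
A hedged usage sketch of the reworked `Classifier` wrapper (it assumes `colossalai.launch` has already set up a tensor parallel configuration, e.g. mode `'1d'`, so `get_tensor_parallel_mode()` returns a mode; the sizes are illustrative):

```python
from colossalai.nn.layer.colossalai_layer import Classifier

# num_classes below the limit: routed to _parallel_classifier[mode]
# (or VanillaClassifier when no tensor parallelism is configured).
cls_head = Classifier(in_features=768, num_classes=10)

# GPT-style vocabulary head: num_classes > vocab_parallel_limit, so the wrapper
# picks the matching VocabParallelClassifier* implementation instead.
lm_head = Classifier(in_features=768, num_classes=50304, vocab_parallel_limit=2048)
```

This mirrors the `Embedding` change above; both default `vocab_parallel_limit` to 2048.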
-from typing import Optional
-
 from colossalai.utils import get_current_device
 from torch import nn
 
+from colossalai import kernel
 from ... import init as init
 from ..parallel_1d import *
@@ -11,7 +10,12 @@ from ..parallel_3d import *
 from ..utils import get_tensor_parallel_mode
 from ..vanilla import *
 
-_parallel_layernorm = {'2d': LayerNorm2D, '2.5d': LayerNorm2p5D, '3d': LayerNorm3D}
+_parallel_layernorm = {
+    '1d': kernel.LayerNorm,
+    '2d': LayerNorm2D,
+    '2.5d': LayerNorm2p5D,
+    '3d': LayerNorm3D
+}
 
 
 class LayerNorm(nn.Module):
@@ -28,11 +32,12 @@ class LayerNorm(nn.Module):
     :param dtype: The dtype of parameters, defaults to None
     :type dtype: torch.dtype, optional
     """
+
     def __init__(self, normalized_shape: int, eps=1e-05, dtype=None) -> None:
         super().__init__()
         tensor_parallel = get_tensor_parallel_mode()
-        if tensor_parallel in ['None', '1d']:
-            self.norm = nn.LayerNorm(normalized_shape, eps=eps, device=get_current_device(), dtype=dtype)
+        if tensor_parallel is None:
+            self.norm = nn.LayerNorm(normalized_shape, eps=eps).to(dtype).to(get_current_device())
         else:
             self.norm = _parallel_layernorm[tensor_parallel](normalized_shape, eps=eps, dtype=dtype)
......