Unverified Commit da01c234 authored by Frank Lee's avatar Frank Lee Committed by GitHub
Browse files

Develop/experiments (#59)



* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>

* Split conv2d, class token, positional embedding in 2d, Fix random number in ddp
Fix convergence in cifar10, Imagenet1000

* Integrate 1d tensor parallel in Colossal-AI (#39)

* fixed 1D and 2D convergence (#38)

* optimized 2D operations

* fixed 1D ViT convergence problem

* Feature/ddp (#49)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* support torch ddp

* fix loss accumulation

* add log for ddp

* change seed

* modify timing hook
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* Feature/pipeline (#40)

* remove redundancy func in setup (#19) (#20)

* use env to control the language of doc (#24) (#25)

* Support TP-compatible Torch AMP and Update trainer API (#27)

* Add gradient accumulation, fix lr scheduler

* fix FP16 optimizer and adapted torch amp with tensor parallel (#18)

* fixed bugs in compatibility between torch amp and tensor parallel and performed some minor fixes

* fixed trainer

* Revert "fixed trainer"

This reverts commit 2e0b0b76990e8d4e337add483d878c0f61cf5097.

* improved consistency between trainer, engine and schedule (#23)
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>

* add an example of ViT-B/16 and remove w_norm clipping in LAMB (#29)

* add explanation for ViT example (#35) (#36)

* optimize communication of pipeline parallel

* fix grad clip for pipeline
Co-authored-by: default avatarFrank Lee <somerlee.9@gmail.com>
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>

* optimized 3d layer to fix slow computation ; tested imagenet performance with 3d; reworked lr_scheduler config definition; fixed launch args; fixed some printing issues; simplified apis of 3d layers (#51)

* Update 2.5d layer code to get a similar accuracy on imagenet-1k dataset

* update api for better usability (#58)

update api for better usability
Co-authored-by: default avatar1SAA <c2h214748@gmail.com>
Co-authored-by: default avatarver217 <lhx0217@gmail.com>
Co-authored-by: default avatarpuck_WCR <46049915+WANG-CR@users.noreply.github.com>
Co-authored-by: default avatarbinmakeswell <binmakeswell@gmail.com>
Co-authored-by: default avatarアマデウス <kurisusnowdeng@users.noreply.github.com>
Co-authored-by: default avatarBoxiangW <45734921+BoxiangW@users.noreply.github.com>
parent eb2f8b1f
from .initialize import init_dist, initialize from .initialize import (initialize, launch, launch_from_openmpi,
from .nn import * launch_from_slurm, launch_from_torch, get_default_parser)
__version__ = '0.0.1' __version__ = '0.0.1'
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from .amp_type import AMP_TYPE
from colossalai.context import Config
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from .torch_amp import convert_to_torch_amp
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
def convert_to_amp(model: nn.Module,
                   optimizer: Optimizer,
                   criterion: _Loss,
                   mode: AMP_TYPE,
                   amp_config: Config = None):
    """Wrap model, optimizer and criterion for the requested AMP mode.

    :param model: model to cast for mixed-precision training
    :param optimizer: optimizer to wrap
    :param criterion: loss function (only wrapped for ``AMP_TYPE.TORCH``)
    :param mode: which AMP backend to use (TORCH, APEX or NAIVE)
    :param amp_config: backend-specific configuration; defaults to an empty Config
    :return: the (model, optimizer, criterion) triple after conversion
    :raises AssertionError: if ``mode`` is not an ``AMP_TYPE`` member
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode be AMP_TYPE, but got {type(mode)}'

    config = Config() if amp_config is None else amp_config

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, config)
    elif mode == AMP_TYPE.APEX:
        # apex and naive backends handle loss scaling internally,
        # so the criterion is returned unchanged
        model, optimizer = convert_to_apex_amp(model, optimizer, config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, config)

    return model, optimizer, criterion
...@@ -7,4 +7,4 @@ from enum import Enum ...@@ -7,4 +7,4 @@ from enum import Enum
class AMP_TYPE(Enum): class AMP_TYPE(Enum):
APEX = 'apex' APEX = 'apex'
TORCH = 'torch' TORCH = 'torch'
PARALLEL = 'parallel' NAIVE = 'naive'
from .apex_amp import ApexAMPOptimizer
import torch.nn as nn
from torch.optim import Optimizer
import apex.amp as apex_amp
def convert_to_apex_amp(model: nn.Module,
                        optimizer: Optimizer,
                        amp_config):
    """Initialize apex AMP and wrap the optimizer.

    :param model: model to hand to ``apex.amp.initialize``
    :param optimizer: optimizer to hand to ``apex.amp.initialize``
    :param amp_config: keyword configuration forwarded to ``apex.amp.initialize``
    :return: the apex-initialized model and the wrapped optimizer
    """
    model, wrapped_optim = apex_amp.initialize(model, optimizer, **amp_config)
    return model, ApexAMPOptimizer(wrapped_optim)


__all__ = ['convert_to_apex_amp', 'ApexAMPOptimizer']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
try:
    import apex.amp as apex_amp
except ImportError:
    # apex is an optional dependency; importing this module must not fail
    # when apex is absent. ApexAMPOptimizer will then raise at use time.
    # (was a bare `except:`, which also swallowed KeyboardInterrupt/SystemExit)
    pass
from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper routing backward and gradient clipping through apex.amp."""

    def backward(self, loss: Tensor):
        """Run the backward pass on the apex-scaled loss."""
        # apex scales the loss so fp16 gradients do not underflow
        with apex_amp.scale_loss(loss, self.optim) as scaled_loss:
            scaled_loss.backward()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip gradients by norm; a non-positive ``max_norm`` disables clipping."""
        if max_norm <= 0:
            return
        # clip the fp32 master copies held by apex, not the model's fp16 params
        clip_grad_norm_fp32(apex_amp.master_params(self.optim), max_norm)
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.utils import is_no_pp_or_last_stage
from .naive_amp import NaiveAMPOptimizer, NaiveAMPModel
def convert_to_naive_amp(model: nn.Module,
                         optimizer: Optimizer,
                         amp_config):
    """Wrap model and optimizer for naive fp16 training.

    Output is cast back to fp32 only when this rank is the last pipeline
    stage or no pipeline parallelism is used (presumably so intermediate
    stages pass fp16 activations along — confirm against the schedule).

    :param model: model to cast to fp16
    :param optimizer: optimizer to wrap in :class:`NaiveAMPOptimizer`
    :param amp_config: keyword configuration forwarded to :class:`NaiveAMPOptimizer`
    :return: the wrapped model and optimizer
    """
    cast_output = is_no_pp_or_last_stage()
    wrapped_model = NaiveAMPModel(model, output_to_fp32=cast_output)
    wrapped_optim = NaiveAMPOptimizer(optimizer, **amp_config)
    return wrapped_model, wrapped_optim


__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer']
...@@ -12,11 +12,9 @@ from torch.optim import Optimizer ...@@ -12,11 +12,9 @@ from torch.optim import Optimizer
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger from colossalai.logging import get_dist_logger
from colossalai.registry import OPTIMIZER_WRAPPERS from colossalai.utils import (print_rank_0, copy_tensor_parallel_attributes,
from colossalai.utils import print_rank_0 clip_grad_norm_fp32, count_zeros_fp32, multi_tensor_applier)
from ._utils import copy_tensor_parallel_attributes, clip_grad_norm_fp32, count_zeros_fp32
from ..multi_tensor_apply import multi_tensor_applier
def _zero_grad_group_helper(group, set_to_none): def _zero_grad_group_helper(group, set_to_none):
...@@ -92,7 +90,7 @@ class DynamicGradScaler: ...@@ -92,7 +90,7 @@ class DynamicGradScaler:
self._growth_tracker = 0 self._growth_tracker = 0
self._hysteresis_tracker = self.hysteresis self._hysteresis_tracker = self.hysteresis
self._logger = get_global_dist_logger() self._logger = get_dist_logger()
@property @property
def scale(self): def scale(self):
...@@ -113,7 +111,7 @@ class DynamicGradScaler: ...@@ -113,7 +111,7 @@ class DynamicGradScaler:
if self._hysteresis_tracker <= 0: if self._hysteresis_tracker <= 0:
self._scale = torch.max(self._scale * self.backoff_factor, self._scale = torch.max(self._scale * self.backoff_factor,
self.min_scale) self.min_scale)
self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}') self._logger.info(f'overflow occurs, loss scale is adjusted to {self._scale}', ranks=[0])
else: else:
# If there is no nan/inf, increment the growth tracker. # If there is no nan/inf, increment the growth tracker.
self._growth_tracker += 1 self._growth_tracker += 1
...@@ -125,10 +123,10 @@ class DynamicGradScaler: ...@@ -125,10 +123,10 @@ class DynamicGradScaler:
# and scale up the loss scale. # and scale up the loss scale.
if self._max_scale is not None and self._scale >= self._max_scale: if self._max_scale is not None and self._scale >= self._max_scale:
self._logger.info( self._logger.info(
f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed') f'Current loss scale {self._scale} has reached the max scale {self._max_scale} allowed', ranks=[0])
else: else:
self._scale = self._scale * self.growth_factor self._scale = self._scale * self.growth_factor
self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}') self._logger.info(f'no consecutive overflow, loss scale is adjusted to {self._scale}', ranks=[0])
def state_dict(self): def state_dict(self):
state_dict = {} state_dict = {}
...@@ -145,7 +143,6 @@ class DynamicGradScaler: ...@@ -145,7 +143,6 @@ class DynamicGradScaler:
self._max_scale = state_dict['max_scale'] self._max_scale = state_dict['max_scale']
@OPTIMIZER_WRAPPERS.register_module
class FP16Optimizer(Optimizer): class FP16Optimizer(Optimizer):
"""Float16 optimizer for fp16 and bf16 data types. """Float16 optimizer for fp16 and bf16 data types.
...@@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer): ...@@ -184,13 +181,13 @@ class FP16Optimizer(Optimizer):
max_scale: int = 2 ** 32): max_scale: int = 2 ** 32):
# default args for compatibility # default args for compatibility
bf16 = False bf16 = False
params_have_main_grad = False params_have_main_grad = True
# have a defaults for compatibility with pytorch optim # have a defaults for compatibility with pytorch optim
self.defaults = optimizer.defaults self.defaults = optimizer.defaults
# log config # log config
self._logger = get_global_dist_logger() self._logger = get_dist_logger()
self._logger.info(f"\n========= FP16 Optimizer Config =========\n" self._logger.info(f"\n========= FP16 Optimizer Config =========\n"
f"Optimizer: {optimizer.__class__.__name__}\n" f"Optimizer: {optimizer.__class__.__name__}\n"
f"clip_grad = {clip_grad}\n" f"clip_grad = {clip_grad}\n"
...@@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer): ...@@ -328,6 +325,7 @@ class FP16Optimizer(Optimizer):
else: else:
if model_param.grad is not None: if model_param.grad is not None:
main_param.grad = model_param.grad.float() main_param.grad = model_param.grad.float()
# For fp32 grads, we need to reset the grads to main grad. # For fp32 grads, we need to reset the grads to main grad.
if self.params_have_main_grad: if self.params_have_main_grad:
for model_group in self.fp32_from_fp32_groups: for model_group in self.fp32_from_fp32_groups:
...@@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer): ...@@ -387,10 +385,6 @@ class FP16Optimizer(Optimizer):
@torch.no_grad() @torch.no_grad()
def step(self): def step(self):
# for param_group in self.float16_groups:
# for param in param_group:
# print(param.grad is None)
# Copy gradients from model params to main params. # Copy gradients from model params to main params.
self._copy_model_grads_to_main_grads() self._copy_model_grads_to_main_grads()
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn as nn
from torch import Tensor
from typing import Union, List, Any, Dict
from torch.optim import Optimizer
import torch.cuda.amp as torch_amp
from colossalai.nn.optimizer import ColossalaiOptimizer
from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper delegating fp16 bookkeeping to :class:`FP16Optimizer`."""

    def __init__(self, optim: Optimizer, *args, **kwargs):
        # FP16Optimizer owns the master weights and loss scaling
        super().__init__(FP16Optimizer(optimizer=optim, *args, **kwargs))

    def backward(self, loss: Tensor):
        """Scale the loss (to avoid fp16 underflow) and run backward."""
        scaled = self.optim.scale_loss(loss)
        scaled.backward()

    def step(self):
        """Perform one optimization step via the wrapped FP16 optimizer."""
        self.optim.step()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        # intentionally a no-op: the inner FP16Optimizer handles clipping itself
        pass
class NaiveAMPModel(nn.Module):
    """Wraps a model for naive fp16 execution.

    The wrapped model's parameters are cast to fp16 in place; fp32 tensor
    inputs are cast to fp16 before the forward pass, and fp16 tensor outputs
    are optionally cast back to fp32.

    :param model: model to cast to half precision (modified in place by ``.half()``)
    :param output_to_fp32: if True, fp16 outputs are converted back to fp32
    """

    def __init__(self,
                 model: nn.Module,
                 output_to_fp32: bool = True):
        super().__init__()
        self.model = model.half()
        self._output_to_fp32 = output_to_fp32

    def _convert_to_fp16(self, input_: Any):
        # only fp32 tensors are cast; non-tensors and other dtypes pass through
        if isinstance(input_, Tensor) and input_.dtype == torch.float32:
            input_ = input_.half()
        return input_

    def _convert_to_fp32(self, input_: Any):
        # only fp16 tensors are cast; non-tensors and other dtypes pass through
        if isinstance(input_, Tensor) and input_.dtype == torch.float16:
            input_ = input_.float()
        return input_

    def forward(self, *args, **kwargs):
        """Cast inputs to fp16, run the model, optionally cast outputs to fp32."""
        if args:
            args = [self._convert_to_fp16(arg) for arg in args]
        if kwargs:
            for k, v in kwargs.items():
                kwargs[k] = self._convert_to_fp16(v)
        out = self.model(*args, **kwargs)
        if self._output_to_fp32:
            if isinstance(out, Tensor):
                out = self._convert_to_fp32(out)
            elif isinstance(out, (tuple, list)):
                # preserve the container type: the previous implementation
                # always returned a list, silently turning tuple outputs into
                # lists and breaking callers that check isinstance(out, tuple)
                out = type(out)(self._convert_to_fp32(val) for val in out)
        return out
import torch.nn as nn
from torch.optim import Optimizer
from torch.nn.modules.loss import _Loss
from colossalai.context import Config
from .torch_amp import TorchAMPOptimizer, TorchAMPModel, TorchAMPLoss
def convert_to_torch_amp(model: nn.Module,
                         optimizer: Optimizer,
                         criterion: _Loss,
                         amp_config: Config):
    """Wrap model, optimizer and criterion for torch-native AMP training.

    :param model: model whose forward pass should run under autocast
    :param optimizer: optimizer to pair with a gradient scaler
    :param criterion: loss function to evaluate under autocast
    :param amp_config: keyword configuration forwarded to the gradient scaler
    :return: the (model, optimizer, criterion) triple after wrapping
    """
    wrapped_model = TorchAMPModel(model)
    wrapped_optim = TorchAMPOptimizer(optimizer, **amp_config)
    wrapped_loss = TorchAMPLoss(criterion)
    return wrapped_model, wrapped_optim, wrapped_loss
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.p #!/usr/bin/env python
# -*- encoding: utf-8 -*-
# modified from https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py
# to support tensor parallel
import torch import torch
from collections import defaultdict, abc from collections import defaultdict, abc
import warnings import warnings
......
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
import torch.cuda.amp as torch_amp
from torch import Tensor
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from ._grad_scaler import GradScaler
from colossalai.nn.optimizer import ColossalaiOptimizer
from colossalai.utils import clip_grad_norm_fp32
class TorchAMPOptimizer(ColossalaiOptimizer):
    """Optimizer wrapper using a gradient scaler for torch-native AMP."""

    def __init__(self, optim: Optimizer, *args, **kwargs):
        super().__init__(optim)
        # extra constructor args configure the gradient scaler
        self.scaler = GradScaler(*args, **kwargs)

    def backward(self, loss: Tensor):
        """Scale the loss (to avoid fp16 gradient underflow) and run backward."""
        self.scaler.scale(loss).backward()

    def step(self):
        """Step the inner optimizer via the scaler, then update the scale."""
        self.scaler.step(self.optim)
        self.scaler.update()

    def clip_grad_norm(self, model: nn.Module, max_norm: float):
        """Clip gradients by norm; a non-positive ``max_norm`` disables clipping."""
        if max_norm <= 0.0:
            return
        # gradients must be unscaled before their norm is measured
        self.scaler.unscale_(self.optim)
        clip_grad_norm_fp32(model.parameters(), max_norm)
class TorchAMPModel(nn.Module):
    """Runs the wrapped model's forward pass under torch AMP autocast."""

    def __init__(self, model: nn.Module) -> None:
        super().__init__()
        self.model = model

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped model inside an autocast region."""
        with torch_amp.autocast():
            return self.model(*args, **kwargs)
class TorchAMPLoss(nn.Module):
    """Evaluates the wrapped loss function under torch AMP autocast."""

    def __init__(self, loss: _Loss):
        super().__init__()
        self.loss = loss

    def forward(self, *args, **kwargs):
        """Forward all arguments to the wrapped loss inside an autocast region."""
        with torch_amp.autocast():
            return self.loss(*args, **kwargs)
from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_optimizer_wrapper, from .builder import (build_schedule, build_lr_scheduler, build_model, build_optimizer, build_layer,
build_layer, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler, build_loss, build_hooks, build_dataset, build_transform, build_data_sampler,
build_gradient_handler) build_gradient_handler)
from .pipeline import ModelInitializer from .pipeline import PipelineModelInitializer
__all__ = [ __all__ = [
'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer', 'build_optimizer_wrapper', 'build_schedule', 'build_lr_scheduler', 'build_model', 'build_optimizer',
'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler', 'build_layer', 'build_loss', 'build_hooks', 'build_dataset', 'build_transform', 'build_data_sampler',
'build_gradient_handler', 'ModelInitializer' 'build_gradient_handler', 'PipelineModelInitializer'
] ]
...@@ -106,7 +106,7 @@ def build_dataset(config): ...@@ -106,7 +106,7 @@ def build_dataset(config):
return build_from_registry(config, DATASETS) return build_from_registry(config, DATASETS)
def build_optimizer(config, model, params: Iterable = None, need_module=False): def build_optimizer(config, model):
"""Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`, """Returns an optimizer object of :class:`torch.optim.Optimizer` constructed from `config`,
'model' and 'params'. 'model' and 'params'.
...@@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False): ...@@ -115,23 +115,12 @@ def build_optimizer(config, model, params: Iterable = None, need_module=False):
:type config: dict or :class:`colossalai.context.Config` :type config: dict or :class:`colossalai.context.Config`
:param model: A model containing parameters for the optimizer :param model: A model containing parameters for the optimizer
:type model: :class:`nn.Module` :type model: :class:`nn.Module`
:param params: A dict containing parameters for the optimizer
:type params: dict, optional
:param need_module: Indicates whether the optimizer needs a module
:type params: bool, optional
:raises AssertionError: Raises an AssertionError if both `model` and `params` are None
:return: An object of :class:`torch.optim.Optimizer` :return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer` :rtype: :class:`torch.optim.Optimizer`
""" """
assert model is not None or params is not None, 'arguments model and params can not both be None' config_ = config.copy()
if need_module: config_['params'] = model.parameters()
config['module'] = model return build_from_registry(config_, OPTIMIZERS)
elif model is not None:
config['params'] = model.parameters()
elif params is not None:
config['params'] = params
return build_from_registry(config, OPTIMIZERS)
def build_gradient_handler(config, model, optimizer): def build_gradient_handler(config, model, optimizer):
...@@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer): ...@@ -149,8 +138,9 @@ def build_gradient_handler(config, model, optimizer):
:rtype: :class:`BaseGradientHandler` :rtype: :class:`BaseGradientHandler`
""" """
config_ = config.copy() config_ = config.copy()
mod_type = config_.pop('type') config_['model'] = model
return GRADIENT_HANDLER.get_module(mod_type)(model, optimizer, **config_) config_['optimizer'] = optimizer
return build_from_registry(config_, GRADIENT_HANDLER)
def build_hooks(config, trainer): def build_hooks(config, trainer):
...@@ -164,8 +154,9 @@ def build_hooks(config, trainer): ...@@ -164,8 +154,9 @@ def build_hooks(config, trainer):
:return: An object of :class:`BaseHook` :return: An object of :class:`BaseHook`
:rtype: :class:`BaseHook` :rtype: :class:`BaseHook`
""" """
config['trainer'] = trainer config_ = config.copy()
return build_from_registry(config, HOOKS) config_['trainer'] = trainer
return build_from_registry(config_, HOOKS)
def build_transform(config): def build_transform(config):
...@@ -195,32 +186,8 @@ def build_data_sampler(config, dataset): ...@@ -195,32 +186,8 @@ def build_data_sampler(config, dataset):
:rtype: :class:`colossalai.nn.data.sampler.BaseSampler` :rtype: :class:`colossalai.nn.data.sampler.BaseSampler`
""" """
config_ = config.copy() config_ = config.copy()
mod_type = config_.pop('type') config_['dataset'] = dataset
return SAMPLERS.get_module(mod_type)(dataset, **config_) return build_from_registry(config_, DATA_SAMPLERS)
def build_optimizer_wrapper(config, optimizer, model=None):
"""Returns an optimizer wrapper object of :class:`torch.optim.Optimizer` constructed
from `config`, `model` and `optimizer`.
:param config: A python dict or a :class:`colossalai.context.Config` object
containing information used in the construction of the return object
:type config: dict or :class:`colossalai.context.Config`
:param optimizer: An optimizer object containing parameters for the gradient handler
:type optimizer: :class:`torch.optim.Optimizer`
:param model: A model containing parameters for the gradient handler
:type model: :class:`nn.Module`, optional
:return: An object of :class:`torch.optim.Optimizer`
:rtype: :class:`torch.optim.Optimizer`
"""
config_ = config.copy()
mod_type = config_.pop('type')
# LSG: special treatment for zeor level 3
if mod_type == 'ZeroRedundancyOptimizer_Level_3':
return OPTIMIZER_WRAPPERS.get_module(mod_type)(model, optimizer, **config_)
else:
return OPTIMIZER_WRAPPERS.get_module(mod_type)(optimizer, **config_)
def build_lr_scheduler(config, optimizer): def build_lr_scheduler(config, optimizer):
...@@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer): ...@@ -241,8 +208,8 @@ def build_lr_scheduler(config, optimizer):
:rtype: :class:`torch.optim.lr_scheduler` :rtype: :class:`torch.optim.lr_scheduler`
""" """
config_ = config.copy() config_ = config.copy()
mod_type = config_.pop('type') config_['optimizer'] = optimizer
return LR_SCHEDULERS.get_module(mod_type)(optimizer, **config_) return build_from_registry(config_, LR_SCHEDULERS)
def build_schedule(config): def build_schedule(config):
......
...@@ -4,7 +4,7 @@ import heapq ...@@ -4,7 +4,7 @@ import heapq
from colossalai.builder import build_model, build_layer from colossalai.builder import build_model, build_layer
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
from colossalai.logging import get_global_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import set_to_cuda from colossalai.utils import set_to_cuda
...@@ -111,21 +111,21 @@ def _binary_search(weights, num): ...@@ -111,21 +111,21 @@ def _binary_search(weights, num):
return intervals return intervals
def _partition_uniform(num_items, num_parts, num_chunks): def _partition_uniform(num_items, pipeline_parallel_size, num_chunks):
assert num_items % num_chunks == 0, \ assert num_items % num_chunks == 0, \
"Layer length should be divided by the number of chunks, otherwise parameter method is recomended" "Layer length should be divided by the number of chunks, otherwise parameter method is recomended"
logger = get_global_dist_logger() logger = get_dist_logger()
parts = [[] for _ in range(num_parts)] parts = [[] for _ in range(pipeline_parallel_size)]
partition_items = num_items // num_chunks partition_items = num_items // num_chunks
for idx in range(num_chunks): for idx in range(num_chunks):
base_idx = idx * partition_items base_idx = idx * partition_items
chunk_size = partition_items // num_parts chunk_size = partition_items // pipeline_parallel_size
left = num_parts - partition_items % num_parts left = pipeline_parallel_size - partition_items % pipeline_parallel_size
if chunk_size == 0: if chunk_size == 0:
logger.warning("Some nodes in Pipeline have no requests") logger.warning("Some nodes in Pipeline have no requests")
for p in range(num_parts): for p in range(pipeline_parallel_size):
st = base_idx st = base_idx
base_idx += chunk_size + (p >= left) base_idx += chunk_size + (p >= left)
parts[p].append((st, base_idx)) parts[p].append((st, base_idx))
...@@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks): ...@@ -133,34 +133,34 @@ def _partition_uniform(num_items, num_parts, num_chunks):
return parts return parts
def _partition_balanced(weights, num_parts, num_chunks): def _partition_balanced(weights, pipeline_parallel_size, num_chunks):
num_total = num_parts * num_chunks num_total = pipeline_parallel_size * num_chunks
num_items = len(weights) num_items = len(weights)
if num_items <= num_total: if num_items <= num_total:
return _partition_uniform(num_items, num_parts, num_chunks) return _partition_uniform(num_items, pipeline_parallel_size, num_chunks)
intervals = _binary_search(weights, num_total) intervals = _binary_search(weights, num_total)
current = 0 current = 0
parts = [[] for _ in range(num_parts)] parts = [[] for _ in range(pipeline_parallel_size)]
for inter in intervals: for inter in intervals:
parts[current].append(inter) parts[current].append(inter)
current = (current + 1) % num_parts current = (current + 1) % pipeline_parallel_size
return parts return parts
class ModelInitializer(): class PipelineModelInitializer():
def __init__(self, config, num_chunks, verbose=False): def __init__(self, config, num_chunks, verbose=False):
self.num_chunks = num_chunks self.num_chunks = num_chunks
self.ori_model = build_model(config) self.ori_model = build_model(config)
self.layers = self.ori_model.layers_cfg self.layers = self.ori_model.layers_cfg
layer_length = len(self.layers) layer_length = len(self.layers)
self.verbose = verbose self.verbose = verbose
self._logger = get_global_dist_logger() self._logger = get_dist_logger()
self._logger.info(f"The total length of layers is {layer_length}", ranks=[0]) self._logger.info(f"The total length of layers is {layer_length}", ranks=[0])
def model_initialize(self, partition_method='parameter'): def initialize(self, partition_method='parameter'):
# Some space for initializing communication groups # Some space for initializing communication groups
self._interval = None self._interval = None
self._partition_layers(method=partition_method) self._partition_layers(method=partition_method)
...@@ -198,7 +198,7 @@ class ModelInitializer(): ...@@ -198,7 +198,7 @@ class ModelInitializer():
for st, ed in self.parts[stage]: for st, ed in self.parts[stage]:
for idx, layer in enumerate(self.layers[st: ed]): for idx, layer in enumerate(self.layers[st: ed]):
log_str += f'\t{idx + st:2d}: {layer}\n' log_str += f'\t{idx + st:2d}: {layer}\n'
self._logger.info(log_str) self._logger.info(log_str, ranks=[0])
# Save the partition # Save the partition
self._interval = self.parts[pipeline_rank] self._interval = self.parts[pipeline_rank]
......
from .collective import all_gather, reduce_scatter, scatter from .collective import all_gather, reduce_scatter, all_reduce
from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward, from .p2p import (send_forward, send_forward_recv_forward, send_backward_recv_forward,
send_backward, send_backward_recv_backward, send_forward_recv_backward, send_backward, send_backward_recv_backward, send_forward_recv_backward,
send_forward_backward_recv_forward_backward, recv_forward, recv_backward) send_forward_backward_recv_forward_backward, recv_forward, recv_backward)
...@@ -6,7 +6,7 @@ from .ring import ring_forward ...@@ -6,7 +6,7 @@ from .ring import ring_forward
from .utils import send_tensor_meta, recv_tensor_meta from .utils import send_tensor_meta, recv_tensor_meta
__all__ = [ __all__ = [
'all_gather', 'reduce_scatter', 'scatter', 'all_gather', 'reduce_scatter', 'all_reduce',
'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward', 'send_forward', 'send_forward_recv_forward', 'send_forward_backward_recv_forward_backward',
'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward', 'send_backward', 'send_backward_recv_backward', 'send_backward_recv_forward',
'send_forward_recv_backward', 'recv_backward', 'recv_forward', 'send_forward_recv_backward', 'recv_backward', 'recv_forward',
......
...@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device ...@@ -11,7 +11,7 @@ from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int, def all_gather(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor: parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a """Gathers all tensors from the parallel group and concatenates them in a
specific dimension. specific dimension.
...@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int, ...@@ -26,18 +26,28 @@ def all_gather(tensor: Tensor, dim: int,
""" """
depth = gpc.get_world_size(parallel_mode) depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone() temp = tensor.clone()
shape = list(temp.shape) # shape = list(temp.shape)
shape[dim] *= depth # shape[dim] *= depth
out = torch.empty(shape, dtype=temp.dtype, device=get_current_device()) # out = torch.zeros(shape, dtype=temp.dtype, device=get_current_device())
out = list(torch.chunk(out, depth, dim=dim)) # out = list(torch.chunk(out, depth, dim=dim))
out = [val.contiguous() for val in out] # out = [val.contiguous() for val in out]
dist.all_gather(out, temp, group=gpc.get_group(parallel_mode)) shape = [1] * len(tensor.shape)
out = torch.cat(out, dim=dim) shape[dim] = depth
return out out = tensor.repeat(shape)
out = list(map(lambda x: x.contiguous(), torch.chunk(out, depth, dim=dim)))
op = dist.all_gather(tensor_list=out,
tensor=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
# out = torch.cat(out, dim=dim)
if async_op:
return out, op
else:
return out
def reduce_scatter(tensor: Tensor, dim: int, def reduce_scatter(tensor: Tensor, dim: int,
parallel_mode: ParallelMode) -> Tensor: parallel_mode: ParallelMode, async_op=False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all """Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group. members in the parallel group.
...@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int, ...@@ -51,34 +61,52 @@ def reduce_scatter(tensor: Tensor, dim: int,
:rtype: Tensor :rtype: Tensor
""" """
depth = gpc.get_world_size(parallel_mode) depth = gpc.get_world_size(parallel_mode)
temp = list(torch.chunk(tensor, depth, dim=dim)) # temp = list(torch.chunk(tensor, depth, dim=dim))
temp = [val.contiguous() for val in temp] # temp = [val.contiguous() for val in temp]
out = torch.empty(temp[0].shape, # out = torch.zeros(temp[0].shape,
dtype=temp[0].dtype, # dtype=temp[0].dtype,
device=get_current_device()) # device=get_current_device())
dist.reduce_scatter(output=out, temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
input_list=temp, out = temp[0].clone()
group=gpc.get_group(parallel_mode)) op = dist.reduce_scatter(output=out,
return out input_list=temp,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, op
else:
return out
def scatter(tensor: Tensor, src: int, dim: int, def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode) -> Tensor: parallel_mode: ParallelMode,
"""Scatters in a specific dimension from source rank to all ranks in async_op=False) -> Tensor:
the parallel group. op = dist.all_reduce(tensor,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return tensor, op
else:
return tensor
# def scatter(tensor: Tensor, src: int, dim: int,
# parallel_mode: ParallelMode) -> Tensor:
# """Scatters in a specific dimension from source rank to all ranks in
# the parallel group.
:param tensor: Tensor to be scattered # :param tensor: Tensor to be scattered
:param dim: The dimension scattering in # :param dim: The dimension scattering in
:param parallel_mode: Parallel group mode used in this communication # :param parallel_mode: Parallel group mode used in this communication
:type tensor: Tensor # :type tensor: Tensor
:type dim: int # :type dim: int
:type parallel_mode: ParallelMode # :type parallel_mode: ParallelMode
:return: The tensor generated by scatter # :return: The tensor generated by scatter
:rtype: Tensor # :rtype: Tensor
""" # """
depth = gpc.get_world_size(parallel_mode) # depth = gpc.get_world_size(parallel_mode)
temp = tensor.clone() # temp = tensor.clone()
dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode)) # dist.broadcast(temp, src=src, group=gpc.get_group(parallel_mode))
rank = gpc.get_local_rank(parallel_mode) # rank = gpc.get_local_rank(parallel_mode)
out = torch.chunk(temp, depth, dim=dim)[rank].contiguous() # out = torch.chunk(temp, depth, dim=dim)[rank].contiguous()
return out # return out
...@@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None, ...@@ -17,8 +17,6 @@ def _communicate(tensor_send_next=None,
recv_next_shape=None, recv_next_shape=None,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None,
up_group=None,
down_group=None,
dtype=None): dtype=None):
""" """
Adapted from megatron.p2p_communication. Adapted from megatron.p2p_communication.
...@@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None, ...@@ -59,60 +57,44 @@ def _communicate(tensor_send_next=None,
if prev_rank is None: if prev_rank is None:
prev_rank = gpc.get_prev_global_rank( prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE) ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
if tensor_send_next is not None or recv_next: if tensor_send_next is not None or recv_next:
if next_rank is None: if next_rank is None:
next_rank = gpc.get_next_global_rank( next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE) ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
# rank = dist.get_rank() # rank = dist.get_rank()
rank = gpc.get_global_rank() rank = gpc.get_global_rank()
ops = [] ops = []
if tensor_send_prev is not None: if tensor_send_prev is not None:
send_prev_op = dist.broadcast(tensor_send_prev, send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
src=rank,
group=up_group,
async_op=True)
ops.append(send_prev_op) ops.append(send_prev_op)
if tensor_recv_prev is not None: if tensor_recv_prev is not None:
recv_prev_op = dist.broadcast(tensor_recv_prev, recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
src=prev_rank,
group=up_group,
async_op=True)
ops.append(recv_prev_op) ops.append(recv_prev_op)
if tensor_recv_next is not None: if tensor_recv_next is not None:
recv_next_op = dist.broadcast(tensor_recv_next, recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
src=next_rank,
group=down_group,
async_op=True)
ops.append(recv_next_op) ops.append(recv_next_op)
if tensor_send_next is not None: if tensor_send_next is not None:
send_next_op = dist.broadcast(tensor_send_next, send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
src=rank,
group=down_group,
async_op=True)
ops.append(send_next_op) ops.append(send_next_op)
for req in ops: if len(ops) > 0:
req.wait() reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv(). # To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize() torch.cuda.synchronize()
return tensor_recv_prev, tensor_recv_next return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, up_group=None): def recv_forward(input_tensor_shape, prev_rank=None):
"""Receives the input tensor from the previous member in pipeline. """Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be recieved :param input_tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor :param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_shape: torch.Size :type input_tensor_shape: torch.Size
:type prev_rank: int, optional :type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The input tensor in forward step :return: The input tensor in forward step
:rtype: Tensor :rtype: Tensor
""" """
...@@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None): ...@@ -121,20 +103,17 @@ def recv_forward(input_tensor_shape, prev_rank=None, up_group=None):
else: else:
input_tensor, _ = _communicate(recv_prev=True, input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape, recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank, prev_rank=prev_rank)
up_group=up_group)
return input_tensor return input_tensor
def recv_backward(output_grad_shape, next_rank=None, down_group=None): def recv_backward(output_grad_shape, next_rank=None):
"""Receives the grad tensor from the next member in pipeline. """Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be recieved :param output_grad_shape: The shape of the tensor to be recieved
:param next_rank: The rank of the source of the tensor :param next_rank: The rank of the source of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_grad_shape: torch.Size :type output_grad_shape: torch.Size
:type next_rank: int, optional :type next_rank: int, optional
:type down_group: ProcessGroup, optional
:return: The grad of output tensor in forward step :return: The grad of output tensor in forward step
:rtype: Tensor :rtype: Tensor
""" """
...@@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None): ...@@ -143,56 +122,44 @@ def recv_backward(output_grad_shape, next_rank=None, down_group=None):
else: else:
_, output_tensor_grad = _communicate(recv_next=True, _, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape, recv_next_shape=output_grad_shape,
next_rank=next_rank, next_rank=next_rank)
down_group=down_group)
return output_tensor_grad return output_tensor_grad
def send_forward(output_tensor, def send_forward(output_tensor, next_rank=None):
next_rank=None,
down_group=None):
"""Sends the input tensor to the next member in pipeline. """Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent :param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor :param next_rank: The rank of the recipient of the tensor
:param down_group: Communication group including the next member in pipeline parallel group
:type output_tensor: Tensor :type output_tensor: Tensor
:type next_rank: int, optional :type next_rank: int, optional
:type down_group: ProcessGroup, optional
""" """
if not gpc.is_last_rank(ParallelMode.PIPELINE): if not gpc.is_last_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_next=output_tensor, _communicate(tensor_send_next=output_tensor,
next_rank=next_rank, next_rank=next_rank)
down_group=down_group)
def send_backward(input_tensor_grad, def send_backward(input_tensor_grad, prev_rank=None):
prev_rank=None,
up_group=None):
"""Sends the grad tensor to the previous member in pipeline. """Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent :param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor :param prev_rank: The rank of the recipient of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type input_tensor_grad: Tensor :type input_tensor_grad: Tensor
:type prev_rank: int, optional :type prev_rank: int, optional
:type up_group: ProcessGroup, optional
""" """
if not gpc.is_first_rank(ParallelMode.PIPELINE): if not gpc.is_first_rank(ParallelMode.PIPELINE):
_communicate(tensor_send_prev=input_tensor_grad, _communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank, prev_rank=prev_rank)
up_group=up_group)
def send_forward_recv_backward(output_tensor, def send_forward_recv_backward(output_tensor,
output_grad_shape, output_grad_shape,
recv_next=True, recv_next=True,
next_rank=None, next_rank=None):
down_group=None):
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the next member in pipeline, while recieves the grad tensor from the
next member in pipeline. next member in pipeline.
:param output_tensor: Tensor to be sent :param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved :param output_grad_shape: The shape of the tensor to be recieved
:type output_tensor: Tensor :type output_tensor: Tensor
...@@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor, ...@@ -206,20 +173,18 @@ def send_forward_recv_backward(output_tensor,
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor, _, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next, recv_next=recv_next,
recv_next_shape=output_grad_shape, recv_next_shape=output_grad_shape,
next_rank=next_rank, next_rank=next_rank)
down_group=down_group)
return output_tensor_grad return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad, def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape, input_tensor_shape,
recv_prev=True, recv_prev=True,
prev_rank=None, prev_rank=None):
up_group=None):
"""Batched communication operation. Sends the grad tensor to the """Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the input tensor from the previous member in pipeline, while recieves the input tensor from the
previous member in pipeline. previous member in pipeline.
:param input_tensor_grad: Tensor to be sent :param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved :param input_tensor_shape: The shape of the tensor to be recieved
:type input_tensor_grad: Tensor :type input_tensor_grad: Tensor
...@@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad, ...@@ -233,8 +198,7 @@ def send_backward_recv_forward(input_tensor_grad,
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad, input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape, recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank, prev_rank=prev_rank)
up_group=up_group)
return input_tensor return input_tensor
...@@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor, ...@@ -242,13 +206,11 @@ def send_forward_recv_forward(output_tensor,
input_tensor_shape, input_tensor_shape,
recv_prev=True, recv_prev=True,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None):
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the """Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the next member in pipeline, while recieves the input tensor from the
previous member in pipeline. previous member in pipeline.
:param output_tensor: Tensor to be sent :param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved :param input_tensor_shape: The shape of the tensor to be recieved
:type output_tensor: Tensor :type output_tensor: Tensor
...@@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor, ...@@ -260,9 +222,7 @@ def send_forward_recv_forward(output_tensor,
recv_prev=recv_prev, recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape, recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank, prev_rank=prev_rank,
next_rank=next_rank, next_rank=next_rank)
up_group=up_group,
down_group=down_group)
return input_tensor return input_tensor
...@@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad, ...@@ -270,13 +230,11 @@ def send_backward_recv_backward(input_tensor_grad,
output_grad_shape, output_grad_shape,
recv_next=True, recv_next=True,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None):
up_group=None,
down_group=None):
"""Batched communication operation. Sends the grad tensor to the """Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the grad tensor from the previous member in pipeline, while recieves the grad tensor from the
next member in pipeline. next member in pipeline.
:param input_tensor_grad: Tensor to be sent :param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved :param output_grad_shape: The shape of the tensor to be recieved
:type input_tensor_grad: Tensor :type input_tensor_grad: Tensor
...@@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad, ...@@ -288,9 +246,7 @@ def send_backward_recv_backward(input_tensor_grad,
recv_next=recv_next, recv_next=recv_next,
recv_next_shape=output_grad_shape, recv_next_shape=output_grad_shape,
prev_rank=prev_rank, prev_rank=prev_rank,
next_rank=next_rank, next_rank=next_rank)
up_group=up_group,
down_group=down_group)
return output_tensor_grad return output_tensor_grad
...@@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor, ...@@ -301,13 +257,11 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev=True, recv_prev=True,
recv_next=True, recv_next=True,
prev_rank=None, prev_rank=None,
next_rank=None, next_rank=None):
up_group=None,
down_group=None):
"""Batched communication operation. Sends the input tensor to the next and """Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while recieves the grad tensor from the the grad tensor to the previous, while recieves the grad tensor from the
next and the input tensor from the previous. next and the input tensor from the previous.
:param output_tensor: Tensor sent to the next :param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous :param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor recieved from the previous :param input_tensor_shape: The shape of the tensor recieved from the previous
...@@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor, ...@@ -327,7 +281,5 @@ def send_forward_backward_recv_forward_backward(output_tensor,
recv_prev_shape=input_tensor_shape, recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape, recv_next_shape=output_grad_shape,
prev_rank=prev_rank, prev_rank=prev_rank,
next_rank=next_rank, next_rank=next_rank)
up_group=up_group,
down_group=down_group)
return input_tensor, output_tensor_grad return input_tensor, output_tensor_grad
...@@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc ...@@ -6,7 +6,7 @@ from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, down_group=None): def send_tensor_meta(tensor, need_meta=True, next_rank=None):
"""Sends tensor meta information before sending a specific tensor. """Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications, Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function meta information of the tensor should be sent before communications. This function
...@@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None): ...@@ -14,31 +14,34 @@ def send_tensor_meta(tensor, need_meta=True, down_group=None):
:param tensor: Tensor to be sent :param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent :param need_meta: If False, meta information won't be sent
:param down_group: Communication group including the next member in pipeline parallel group :param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor :type tensor: Tensor
:type need_meta: bool, optional :type need_meta: bool, optional
:type down_group: ProcessGroup, optional :type next_rank: int
:return: False :return: False
:rtype: bool :rtype: bool
""" """
if need_meta: if need_meta:
rank = gpc.get_global_rank() if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
if down_group is None:
down_group = gpc.get_group(ParallelMode.PIPELINE_NEXT)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs) send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs) send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
ops = [
dist.broadcast(send_ndims, src=rank, group=down_group) dist.P2POp(dist.isend, send_ndims, next_rank),
dist.broadcast(send_shape, src=rank, group=down_group) dist.P2POp(dist.isend, send_shape, next_rank)
]
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
torch.cuda.synchronize()
return False return False
def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None): def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor. """Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications, Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function meta information of the tensor should be recieved before communications. This function
...@@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None): ...@@ -46,27 +49,21 @@ def recv_tensor_meta(tensor_shape, prev_rank=None, up_group=None):
:param tensor_shape: The shape of the tensor to be recieved :param tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor :param prev_rank: The rank of the source of the tensor
:param up_group: Communication group including the previous member in pipeline parallel group
:type tensor_shape: torch.Size :type tensor_shape: torch.Size
:type prev_rank: int, optional :type prev_rank: int, optional
:type up_group: ProcessGroup, optional
:return: The shape of the tensor to be recieved :return: The shape of the tensor to be recieved
:rtype: torch.Size :rtype: torch.Size
""" """
if tensor_shape is None: if tensor_shape is None:
if prev_rank is None: if prev_rank is None:
prev_rank = gpc.get_prev_global_rank( prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
ParallelMode.PIPELINE)
if up_group is None:
up_group = gpc.get_group(ParallelMode.PIPELINE_PREV)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()} tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs) recv_ndims = torch.empty((), **tensor_kwargs)
dist.broadcast(recv_ndims, src=prev_rank, group=up_group) dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs) recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.broadcast(recv_shape, src=prev_rank, group=up_group) dist.recv(recv_shape, prev_rank)
tensor_shape = torch.Size(recv_shape) tensor_shape = torch.Size(recv_shape)
......
...@@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP' ...@@ -25,7 +25,11 @@ TESSERACT_DEP = 'TESSERACT_DEP'
# 3D parallel # 3D parallel
DEPTH_3D = 'DEPTH_3D' DEPTH_3D = 'DEPTH_3D'
INPUT_GROUP_3D = 'PARALLEL_3D_INPUT'
WEIGHT_GROUP_3D = 'PARALLEL_3D_WEIGHT'
OUTPUT_GROUP_3D = 'PARALLEL_3D_OUTPUT'
# Tensor parallel attributes # Tensor parallel attributes
IS_TENSOR_PARALLEL = 'is_tensor_parallel' IS_TENSOR_PARALLEL = 'is_tensor_parallel'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL] NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
from .config import Config from .config import Config, ConfigException
from .parallel_context import ParallelContext from .parallel_context import ParallelContext
from .parallel_context import ParallelMode from .parallel_mode import ParallelMode
from .process_group_initializer import * from .process_group_initializer import *
from .random import * from .random import *
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment