Unverified Commit b5f9e37c authored by Hongxin Liu's avatar Hongxin Liu Committed by GitHub
Browse files

[legacy] clean up legacy code (#4743)

* [legacy] remove outdated codes of pipeline (#4692)

* [legacy] remove cli of benchmark and update optim (#4690)

* [legacy] remove cli of benchmark and update optim

* [doc] fix cli doc test

* [legacy] fix engine clip grad norm

* [legacy] remove outdated colo tensor (#4694)

* [legacy] remove outdated colo tensor

* [test] fix test import

* [legacy] move outdated zero to legacy (#4696)

* [legacy] clean up utils (#4700)

* [legacy] clean up utils

* [example] update examples

* [legacy] clean up amp

* [legacy] fix amp module

* [legacy] clean up gpc (#4742)

* [legacy] clean up context

* [legacy] clean core, constants and global vars

* [legacy] refactor initialize

* [example] fix examples ci

* [example] fix examples ci

* [legacy] fix tests

* [example] fix gpt example

* [example] fix examples ci

* [devops] fix ci installation

* [example] fix examples ci
parent 32e7f994
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer
from colossalai.context import Config
from .amp_type import AMP_TYPE
from .apex_amp import convert_to_apex_amp
from .naive_amp import convert_to_naive_amp
from .torch_amp import convert_to_torch_amp
__all__ = ['convert_to_amp', 'convert_to_naive_amp', 'convert_to_apex_amp', 'convert_to_torch_amp', 'AMP_TYPE']
def convert_to_amp(model: nn.Module, optimizer: Optimizer, criterion: _Loss, mode: AMP_TYPE, amp_config: Config = None):
    """A helper function to wrap training components with AMP modules.

    Dispatches to the Torch, Apex, or naive AMP converter depending on ``mode``.

    Args:
        model (:class:`torch.nn.Module`): your model object.
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object.
        criterion (:class:`torch.nn.modules.loss._Loss`): your loss function object.
        mode (:class:`colossalai.legacy.amp.AMP_TYPE`): amp mode.
        amp_config (Union[:class:`colossalai.context.Config`, dict]): configuration for different amp modes.

    Returns:
        A tuple (model, optimizer, criterion).

    Note:
        ``amp_config`` may vary from different mode you choose. You should check the corresponding amp mode
        for more details about ``amp_config``.
        For ``apex_amp``, please check
        `apex_amp config <https://nvidia.github.io/apex/amp.html?highlight=apex%20amp>`_.
        For ``naive_amp``, please check
        `naive_amp config <https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/amp/naive_amp/_fp16_optimizer.py#L42>`_.
        For ``torch_amp``, please check
        `torch_amp config <https://github.com/pytorch/pytorch/blob/master/torch/cuda/amp/grad_scaler.py#L97>`_.
    """
    assert isinstance(mode, AMP_TYPE), \
        f'expected the argument mode to be AMP_TYPE, but got {type(mode)}'

    if amp_config is None:
        amp_config = Config()

    if mode == AMP_TYPE.TORCH:
        model, optimizer, criterion = convert_to_torch_amp(model, optimizer, criterion, amp_config)
    elif mode == AMP_TYPE.APEX:
        # Apex and naive AMP wrap only model and optimizer; criterion is returned unchanged.
        model, optimizer = convert_to_apex_amp(model, optimizer, amp_config)
    elif mode == AMP_TYPE.NAIVE:
        model, optimizer = convert_to_naive_amp(model, optimizer, amp_config)

    return model, optimizer, criterion
...@@ -10,11 +10,11 @@ except ImportError: ...@@ -10,11 +10,11 @@ except ImportError:
from torch import Tensor from torch import Tensor
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.interface import OptimizerWrapper
from colossalai.utils import clip_grad_norm_fp32 from colossalai.legacy.utils import clip_grad_norm_fp32
class ApexAMPOptimizer(ColossalaiOptimizer): class ApexAMPOptimizer(OptimizerWrapper):
""" A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm """ A wrapper class for APEX optimizer and it implements apex-specific backward and clip_grad_norm
methods methods
""" """
......
import inspect
import torch.nn as nn
from torch.optim import Optimizer
from colossalai.amp.naive_amp.grad_scaler import ConstantGradScaler, DynamicGradScaler
from colossalai.legacy.utils import is_no_pp_or_last_stage
from ._fp16_optimizer import FP16Optimizer
from .naive_amp import NaiveAMPModel, NaiveAMPOptimizer
def convert_to_naive_amp(model: nn.Module, optimizer: Optimizer, amp_config):
    """A helper function to wrap training components with naive AMP modules. In this mode,
    we forcibly cast the model weights and inputs to FP16, and cast the model outputs to FP32 to calculate loss,
    which is equivalent to Apex O3.

    Args:
        model (:class:`torch.nn.Module`): your model object
        optimizer (:class:`torch.optim.Optimizer`): your optimizer object
        amp_config (:class:`colossalai.context.Config` or dict): configuration for naive mode amp.

    Returns:
        Tuple: A tuple (model, optimizer)

    The ``amp_config`` should contain parameters below::

        verbose (bool, optional): if set to `True`, will print debug info (Default: False).
        clip_grad_norm (float, optional): clip gradients with this global L2 norm (Default 0).
                                          Note that clipping is ignored if clip_grad == 0.
        dynamic_grad_scale (bool): whether to use dynamic grad scaler.
    """
    # Work on a shallow copy so the caller's config dict is not mutated by the
    # ``pop`` calls below (the original implementation destructively consumed it).
    amp_config = dict(amp_config)

    if isinstance(model, nn.ModuleList):
        # interleaved pipeline: only the last chunk on the last stage casts its output to fp32
        module_list = []
        for chunk, m in enumerate(model):
            output_to_fp32 = is_no_pp_or_last_stage() and chunk == len(model) - 1
            module_list.append(NaiveAMPModel(m, output_to_fp32=output_to_fp32))
        model = nn.ModuleList(module_list)
    else:
        output_to_fp32 = is_no_pp_or_last_stage()
        model = NaiveAMPModel(model, output_to_fp32=output_to_fp32)

    use_dynamic_grad_scaler = amp_config.pop('dynamic_grad_scale', True)
    scaler_class = DynamicGradScaler if use_dynamic_grad_scaler else ConstantGradScaler

    # Forward only the kwargs accepted by the chosen scaler's constructor;
    # whatever remains in amp_config is passed through to the optimizer wrapper.
    sig = inspect.signature(scaler_class.__init__)
    kwargs = dict()
    for param in sig.parameters.values():
        if param.name in amp_config:
            kwargs[param.name] = amp_config.pop(param.name)
    grad_scaler = scaler_class(**kwargs)
    optimizer = NaiveAMPOptimizer(optimizer, grad_scaler, **amp_config)
    return model, optimizer
__all__ = ['convert_to_naive_amp', 'NaiveAMPOptimizer', 'FP16Optimizer']
...@@ -6,14 +6,15 @@ import torch.distributed as dist ...@@ -6,14 +6,15 @@ import torch.distributed as dist
from torch.distributed import ProcessGroup from torch.distributed import ProcessGroup
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.context import ParallelMode from colossalai.amp.naive_amp.grad_scaler import BaseGradScaler
from colossalai.core import global_context as gpc
from colossalai.kernel.op_builder import FusedOptimBuilder from colossalai.kernel.op_builder import FusedOptimBuilder
from colossalai.legacy.context import ParallelMode
from colossalai.legacy.core import global_context as gpc
from colossalai.legacy.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
from colossalai.utils import clip_grad_norm_fp32, copy_tensor_parallel_attributes, multi_tensor_applier from colossalai.utils import multi_tensor_applier
from ._utils import has_inf_or_nan, zero_gard_by_list from ._utils import has_inf_or_nan, zero_gard_by_list
from .grad_scaler import BaseGradScaler
try: try:
from colossalai._C import fused_optim from colossalai._C import fused_optim
......
...@@ -11,14 +11,14 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors ...@@ -11,14 +11,14 @@ from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.context import ParallelMode from colossalai.interface import OptimizerWrapper
from colossalai.core import global_context as gpc from colossalai.legacy.context import ParallelMode
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.legacy.core import global_context as gpc
from ._fp16_optimizer import FP16Optimizer from ._fp16_optimizer import FP16Optimizer
class NaiveAMPOptimizer(ColossalaiOptimizer): class NaiveAMPOptimizer(OptimizerWrapper):
"""A wrapper class for optimizer to cast all parameters to fp16 """A wrapper class for optimizer to cast all parameters to fp16
Args: Args:
...@@ -57,7 +57,7 @@ class NaiveAMPModel(nn.Module): ...@@ -57,7 +57,7 @@ class NaiveAMPModel(nn.Module):
Args: Args:
model (torch.nn.Module): torch.nn.Module to be wrapped. model (torch.nn.Module): torch.nn.Module to be wrapped.
output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True) output_to_fp32 (bool, optional): Whether cast output of this module into fp32. (Default: True)
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this module. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this module.
(Default: ``ParallelMode.DATA``) (Default: ``ParallelMode.DATA``)
sync_buffer (bool, optional): whether to synchronize buffer. (Default: True) sync_buffer (bool, optional): whether to synchronize buffer. (Default: True)
......
...@@ -13,8 +13,8 @@ import torch.distributed as dist ...@@ -13,8 +13,8 @@ import torch.distributed as dist
from packaging import version from packaging import version
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors
from colossalai.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
class _MultiDeviceReplicator(object): class _MultiDeviceReplicator(object):
......
...@@ -7,13 +7,13 @@ from torch import Tensor ...@@ -7,13 +7,13 @@ from torch import Tensor
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
from torch.optim import Optimizer from torch.optim import Optimizer
from colossalai.nn.optimizer import ColossalaiOptimizer from colossalai.interface import OptimizerWrapper
from colossalai.utils import clip_grad_norm_fp32 from colossalai.legacy.utils import clip_grad_norm_fp32
from ._grad_scaler import GradScaler from ._grad_scaler import GradScaler
class TorchAMPOptimizer(ColossalaiOptimizer): class TorchAMPOptimizer(OptimizerWrapper):
"""A wrapper class which integrate Pytorch AMP with an optimizer """A wrapper class which integrate Pytorch AMP with an optimizer
Args: Args:
......
...@@ -6,8 +6,8 @@ import torch.distributed as dist ...@@ -6,8 +6,8 @@ import torch.distributed as dist
from torch import Tensor from torch import Tensor
from torch.distributed import ReduceOp from torch.distributed import ReduceOp
from colossalai.context import ParallelMode from colossalai.legacy.context import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
_all_gather_func = dist._all_gather_base \ _all_gather_func = dist._all_gather_base \
if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor if "all_gather_into_tensor" not in dir(dist) else dist.all_gather_into_tensor
...@@ -26,7 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: ...@@ -26,7 +26,7 @@ def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op:
Args: Args:
tensor (:class:`torch.Tensor`): Tensor to be gathered. tensor (:class:`torch.Tensor`): Tensor to be gathered.
dim (int): The dimension concatenating in. dim (int): The dimension concatenating in.
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous. async_op (bool, optional): Whether operations are asynchronous.
Returns: Returns:
...@@ -65,7 +65,7 @@ def reduce_scatter(tensor: Tensor, ...@@ -65,7 +65,7 @@ def reduce_scatter(tensor: Tensor,
Args: Args:
tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered. tensor (:class:`torch.Tensor`): Tensor to be reduce_scattered.
dim (int): The dimension concatenating in. dim (int): The dimension concatenating in.
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
op (torch.distributed.ReduceOp, optional): The type of reduce operation, op (torch.distributed.ReduceOp, optional): The type of reduce operation,
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
More details about ReduceOp please refer to More details about ReduceOp please refer to
...@@ -105,7 +105,7 @@ def all_reduce(tensor: Tensor, ...@@ -105,7 +105,7 @@ def all_reduce(tensor: Tensor,
Args: Args:
tensor (:class:`torch.Tensor`): Tensor to be all-reduced. tensor (:class:`torch.Tensor`): Tensor to be all-reduced.
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
op (torch.distributed.ReduceOp, optional): The type of reduce operation, op (torch.distributed.ReduceOp, optional): The type of reduce operation,
should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR]. should be included in [SUM, AVG, PRODUCT, MIN, MAX, BAND, BOR, BXOR].
More details about ReduceOp please refer to More details about ReduceOp please refer to
...@@ -141,7 +141,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b ...@@ -141,7 +141,7 @@ def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: b
Args: Args:
tensor (:class:`torch.Tensor`): Tensor to be broadcast. tensor (:class:`torch.Tensor`): Tensor to be broadcast.
src (int): Source rank. src (int): Source rank.
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous. async_op (bool, optional): Whether operations are asynchronous.
Returns: Returns:
...@@ -173,7 +173,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ...@@ -173,7 +173,7 @@ def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp =
Args: Args:
tensor (:class:`torch.Tensor`): Tensor to be reduced. tensor (:class:`torch.Tensor`): Tensor to be reduced.
dst (int): Destination rank. dst (int): Destination rank.
parallel_mode (:class:`colossalai.context.ParallelMode`): Parallel group mode used in this communication. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): Parallel group mode used in this communication.
async_op (bool, optional): Whether operations are asynchronous. async_op (bool, optional): Whether operations are asynchronous.
Returns: Returns:
......
...@@ -8,8 +8,8 @@ from typing import List, Tuple, Union ...@@ -8,8 +8,8 @@ from typing import List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks from .utils import gather_split_1d_tensor, split_tensor_into_1d_equal_chunks
......
...@@ -10,8 +10,8 @@ import torch.distributed as dist ...@@ -10,8 +10,8 @@ import torch.distributed as dist
from torch.distributed import ProcessGroupNCCL from torch.distributed import ProcessGroupNCCL
from torch.distributed import distributed_c10d as c10d from torch.distributed import distributed_c10d as c10d
from colossalai.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
TensorShape = Union[torch.Size, List[int], Tuple[int]] TensorShape = Union[torch.Size, List[int], Tuple[int]]
_pg_manager = {} _pg_manager = {}
......
...@@ -3,8 +3,8 @@ ...@@ -3,8 +3,8 @@
import torch import torch
from colossalai.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize from colossalai.utils import get_current_device, synchronize
......
...@@ -3,8 +3,8 @@ from typing import List, Tuple, Union ...@@ -3,8 +3,8 @@ from typing import List, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode from colossalai.legacy.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.legacy.core import global_context as gpc
from colossalai.utils import get_current_device from colossalai.utils import get_current_device
TensorShape = Union[torch.Size, List[int], Tuple[int]] TensorShape = Union[torch.Size, List[int], Tuple[int]]
......
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
...@@ -11,10 +11,10 @@ import numpy as np ...@@ -11,10 +11,10 @@ import numpy as np
import torch import torch
import torch.distributed as dist import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config from colossalai.context.config import Config
from colossalai.context.singleton_meta import SingletonMeta from colossalai.context.singleton_meta import SingletonMeta
from colossalai.global_variables import tensor_parallel_env as env from colossalai.legacy.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.legacy.global_variables import tensor_parallel_env as env
from colossalai.legacy.registry import DIST_GROUP_INITIALIZER from colossalai.legacy.registry import DIST_GROUP_INITIALIZER
from colossalai.logging import get_dist_logger from colossalai.logging import get_dist_logger
...@@ -110,12 +110,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -110,12 +110,12 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds the global rank of the current device for `parallel_mode` to the context. """Adds the global rank of the current device for `parallel_mode` to the context.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added rank (int): The rank to be added
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank self._global_ranks[parallel_mode] = rank
...@@ -124,11 +124,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -124,11 +124,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the local rank of the current device. """Returns the local rank of the current device.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
int: The local rank of the current device for `parallel_mode`. int: The local rank of the current device for `parallel_mode`.
...@@ -140,12 +140,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -140,12 +140,12 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds the local rank of the current device for `parallel_mode` to the context. """Adds the local rank of the current device for `parallel_mode` to the context.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode for the rank. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode for the rank.
rank (int): The rank to be added. rank (int): The rank to be added.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank self._local_ranks[parallel_mode] = rank
...@@ -154,11 +154,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -154,11 +154,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the global rank of the next device. """Returns the global rank of the next device.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
int: The global rank of the next device for `parallel_mode`. int: The global rank of the next device for `parallel_mode`.
...@@ -176,11 +176,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -176,11 +176,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the global rank of the previous device. """Returns the global rank of the previous device.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
int: The global rank of the previous device for `parallel_mode`. int: The global rank of the previous device for `parallel_mode`.
...@@ -199,11 +199,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -199,11 +199,11 @@ class ParallelContext(metaclass=SingletonMeta):
among its group for `parallel_mode`. among its group for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
bool: a boolean value indicating whether the current device is the first one bool: a boolean value indicating whether the current device is the first one
...@@ -217,11 +217,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -217,11 +217,11 @@ class ParallelContext(metaclass=SingletonMeta):
among its group for `parallel_mode`. among its group for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
bool: a boolean value indicating whether the current device is the first one bool: a boolean value indicating whether the current device is the first one
...@@ -248,11 +248,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -248,11 +248,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the world size for `parallel_mode`. """Returns the world size for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
int: The world size for `parallel_mode`. int: The world size for `parallel_mode`.
...@@ -264,12 +264,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -264,12 +264,12 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds world size for `parallel_mode`. """Adds world size for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The parallel mode corresponding to the process group parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The parallel mode corresponding to the process group
world_size (int): The world size to be added world_size (int): The world size to be added
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size self._world_sizes[parallel_mode] = world_size
...@@ -278,11 +278,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -278,11 +278,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the group of the current device for `parallel_mode`. """Returns the group of the current device for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`. torch.distributed.ProcessGroup: The group of the current device for `parallel_mode`.
...@@ -294,12 +294,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -294,12 +294,12 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds the group of the current device for `parallel_mode`. """Adds the group of the current device for `parallel_mode`.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
group (torch.distributed.ProcessGroup): The group to be added group (torch.distributed.ProcessGroup): The group to be added
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group self._groups[parallel_mode] = group
...@@ -308,9 +308,9 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -308,9 +308,9 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the Gloo group of the current device for `parallel_mode`. """Returns the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode :param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode` :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` of :class:`colossalai.legacy.context.ParallelMode`
:return: The group of the current device for `parallel_mode` :return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup :rtype: torch.distributed.ProcessGroup
""" """
...@@ -321,11 +321,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -321,11 +321,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds the Gloo group of the current device for `parallel_mode`. """Adds the Gloo group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode :param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode` :type parallel_mode: :class:`colossalai.legacy.context.ParallelMode`
:param group: The group to be added :param group: The group to be added
:type group: torch.distributed.ProcessGroup :type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance :raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode` of :class:`colossalai.legacy.context.ParallelMode`
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._cpu_groups[parallel_mode] = group self._cpu_groups[parallel_mode] = group
...@@ -334,11 +334,11 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -334,11 +334,11 @@ class ParallelContext(metaclass=SingletonMeta):
"""Returns the rank of the current device for `parallel_mode` in the group. """Returns the rank of the current device for `parallel_mode` in the group.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
Returns: Returns:
int: The rank of the current device for `parallel_mode` in the group. int: The rank of the current device for `parallel_mode` in the group.
...@@ -350,12 +350,12 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -350,12 +350,12 @@ class ParallelContext(metaclass=SingletonMeta):
"""Adds the ranks of the current device for `parallel_mode` in the group. """Adds the ranks of the current device for `parallel_mode` in the group.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
ranks (list): List of ranks to be added ranks (list): List of ranks to be added
Raises: Raises:
AssertionError: Raises an AssertionError if `parallel_mode` is not an instance AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`. of :class:`colossalai.legacy.context.ParallelMode`.
""" """
self._check_parallel_mode(parallel_mode) self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks self._ranks_in_group[parallel_mode] = ranks
...@@ -489,7 +489,7 @@ class ParallelContext(metaclass=SingletonMeta): ...@@ -489,7 +489,7 @@ class ParallelContext(metaclass=SingletonMeta):
in the current system. in the current system.
Args: Args:
parallel_mode (:class:`colossalai.context.ParallelMode`): The chosen parallel mode. parallel_mode (:class:`colossalai.legacy.context.ParallelMode`): The chosen parallel mode.
Returns: Returns:
bool: a boolean value indicating whether `parallel_mode` is initialized in the current system. bool: a boolean value indicating whether `parallel_mode` is initialized in the current system.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment