Commit da3f0934 authored by zhuwenwen's avatar zhuwenwen
Browse files

delete unused files

parent c4dd1fd4
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.distributed as dist
from torch.distributed import ReduceOp
from torch import Tensor
from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def all_gather(tensor: Tensor, dim: int, parallel_mode: ParallelMode, async_op: bool = False) -> Tensor:
"""Gathers all tensors from the parallel group and concatenates them in a
specific dimension.
:param tensor: Tensor to be gathered
:param dim: The dimension concatenating in
:param parallel_mode: Parallel group mode used in this communication
:param async_op: Whether operations are asynchronous
:type tensor: :class:`torch.Tensor`
:type dim: int
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:type async_op: bool, optional
:return: The tensor generated by all-gather
:rtype: :class:`torch.Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
shape = list(tensor.shape)
shape[0], shape[dim] = shape[dim], shape[0]
shape[0] *= depth
out = torch.empty(shape, dtype=tensor.dtype, device=get_current_device())
temp = list(torch.chunk(out, depth, dim=0))
work = dist.all_gather(tensor_list=temp,
tensor=tensor.transpose(0, dim).contiguous(),
group=gpc.get_group(parallel_mode),
async_op=async_op)
out = torch.transpose(out, 0, dim)
if async_op:
return out, work
else:
return out
def reduce_scatter(tensor: Tensor,
dim: int,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
"""Reduces all tensors then scatters it in a specific dimension to all
members in the parallel group.
:param tensor: Tensor to be reduced and scattered
:param dim: The dimension scattering in
:param parallel_mode: Parallel group mode used in this communication
:param op: The type of reduce operation
:param async_op: Whether operations are asynchronous
:type tensor: :class:`torch.Tensor`
:type dim: int
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:type op: ReduceOp, optional
:type async_op: bool, optional
:return: The tensor generated by reduce-scatter
:rtype: :class:`Tensor`
"""
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
temp = list(map(lambda x: x.contiguous(), torch.chunk(tensor, depth, dim=dim)))
out = torch.empty(temp[0].shape, dtype=tensor.dtype, device=get_current_device())
work = dist.reduce_scatter(output=out,
input_list=temp,
op=op,
group=gpc.get_group(parallel_mode),
async_op=async_op)
if async_op:
return out, work
else:
return out
def all_reduce(tensor: Tensor,
parallel_mode: ParallelMode,
op: ReduceOp = ReduceOp.SUM,
async_op: bool = False) -> Tensor:
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
out = tensor.contiguous()
work = dist.all_reduce(out, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return out, work
else:
return out
def broadcast(tensor: Tensor, src: int, parallel_mode: ParallelMode, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
out = tensor.contiguous()
work = dist.broadcast(out, src=src, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return out, work
else:
return out
def reduce(tensor: Tensor, dst: int, parallel_mode: ParallelMode, op: ReduceOp = ReduceOp.SUM, async_op: bool = False):
depth = gpc.get_world_size(parallel_mode)
if depth == 1:
out = tensor
work = None
else:
out = tensor.contiguous()
work = dist.reduce(out, dst=dst, op=op, group=gpc.get_group(parallel_mode), async_op=async_op)
if async_op:
return out, work
else:
return out
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import List, Tuple, Union
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
from functools import reduce
import operator
from .utils import split_tensor_into_1d_equal_chunks, gather_split_1d_tensor
TensorShape = Union[torch.Size, List[int], Tuple[int]]
def _get_tensor_shape(tensor_shape: TensorShape, chunk_tensor: bool = False) -> Tuple[TensorShape, bool]:
"""get the exact tensor shape when communicating and return whether the tensor is a chunk
:param tensor_shape: shape of tensor
:type tensor_shape: TensorShape
:param chunk_tensor: whether to chunk tensor, defaults to False
:type chunk_tensor: bool, optional
:return: exact tensor shape, whether to chunk tensor
:rtype: Tuple[Union[torch.Size, List[int], Tuple[int]], bool]
"""
if chunk_tensor:
tensor_chunk_shape = reduce(operator.mul, tensor_shape, 1)
tensor_parallel_world_size = gpc.get_world_size(ParallelMode.TENSOR)
if tensor_chunk_shape % tensor_parallel_world_size == 0:
tensor_chunk_shape = tensor_chunk_shape // tensor_parallel_world_size
else:
tensor_chunk_shape = tensor_shape
chunk_tensor = False
else:
tensor_chunk_shape = tensor_shape
return tensor_chunk_shape, chunk_tensor
def _communicate(tensor_send_next=None,
tensor_send_prev=None,
recv_prev=False,
recv_next=False,
recv_prev_shape=None,
recv_next_shape=None,
prev_rank=None,
next_rank=None,
dtype=None,
scatter_gather_tensors=False):
"""
Adapted from megatron.p2p_communication.
Communicate tensors between stages. Used as helper method in other
communication methods that are used in pipeline schedule.
Takes the following arguments:
tensor_send_next: tensor to send to next rank (no tensor sent if
set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if
set to None).
recv_prev: boolean for whether tensor should be received from
previous rank.
recv_next: boolean for whether tensor should be received from
next rank.
Returns:
(tensor_recv_prev, tensor_recv_next)
"""
# Create placeholder tensors for receive in forward and backward directions
# if needed.
tensor_recv_prev = None
tensor_recv_next = None
if recv_prev:
assert recv_prev_shape is not None
recv_prev_chunk_shape, recv_prev_split = _get_tensor_shape(recv_prev_shape, scatter_gather_tensors)
tensor_recv_prev = torch.empty(recv_prev_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if recv_next:
assert recv_next_shape is not None
recv_next_chunk_shape, recv_next_split = _get_tensor_shape(recv_next_shape, scatter_gather_tensors)
tensor_recv_next = torch.empty(recv_next_chunk_shape,
requires_grad=True,
device=get_current_device(),
dtype=dtype)
if tensor_send_prev is not None or recv_prev:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(
ParallelMode.PIPELINE)
if tensor_send_next is not None or recv_next:
if next_rank is None:
next_rank = gpc.get_next_global_rank(
ParallelMode.PIPELINE)
if tensor_send_prev is not None:
send_prev_split = _get_tensor_shape(tensor_send_prev.shape, scatter_gather_tensors)[1]
if send_prev_split:
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
if tensor_send_next is not None:
send_next_split = _get_tensor_shape(tensor_send_next.shape, scatter_gather_tensors)[1]
if send_next_split:
tensor_send_next = split_tensor_into_1d_equal_chunks(tensor_send_next)
ops = []
if tensor_send_prev is not None:
send_prev_op = dist.P2POp(dist.isend, tensor_send_prev, prev_rank)
ops.append(send_prev_op)
if tensor_recv_prev is not None:
recv_prev_op = dist.P2POp(dist.irecv, tensor_recv_prev, prev_rank)
ops.append(recv_prev_op)
if tensor_recv_next is not None:
recv_next_op = dist.P2POp(dist.irecv, tensor_recv_next, next_rank)
ops.append(recv_next_op)
if tensor_send_next is not None:
send_next_op = dist.P2POp(dist.isend, tensor_send_next, next_rank)
ops.append(send_next_op)
if len(ops) > 0:
reqs = dist.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
if recv_prev and recv_prev_split:
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
if recv_next and recv_next_split:
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
return tensor_recv_prev, tensor_recv_next
def recv_forward(input_tensor_shape, prev_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the input tensor from the previous member in pipeline.
:param input_tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:type input_tensor_shape: torch.Size
:type prev_rank: int, optional
:return: The input tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(recv_prev=True,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
def recv_backward(output_grad_shape, next_rank=None, dtype=torch.float, scatter_gather_tensors=False):
"""Receives the grad tensor from the next member in pipeline.
:param output_grad_shape: The shape of the tensor to be recieved
:param next_rank: The rank of the source of the tensor
:type output_grad_shape: torch.Size
:type next_rank: int, optional
:return: The grad of output tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(recv_next=True,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
def send_forward(output_tensor, next_rank=None, scatter_gather_tensors=False):
"""Sends the input tensor to the next member in pipeline.
:param output_tensor: Tensor to be sent
:param next_rank: The rank of the recipient of the tensor
:type output_tensor: :class:`torch.Tensor`
:type next_rank: int, optional
"""
if not gpc.is_pipeline_last_stage():
_communicate(tensor_send_next=output_tensor,
next_rank=next_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_backward(input_tensor_grad, prev_rank=None, scatter_gather_tensors=False):
"""Sends the grad tensor to the previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param prev_rank: The rank of the recipient of the tensor
:type input_tensor_grad: :class:`torch.Tensor`
:type prev_rank: int, optional
"""
if not gpc.is_pipeline_first_stage():
_communicate(tensor_send_prev=input_tensor_grad,
prev_rank=prev_rank,
scatter_gather_tensors=scatter_gather_tensors)
def send_forward_recv_backward(output_tensor,
output_grad_shape,
recv_next=True,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param output_tensor: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved
:type output_tensor: :class:`torch.Tensor`
:type output_grad_shape: :class:`torch.Size`
:return: The grad of output tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_pipeline_last_stage():
output_tensor_grad = None
else:
_, output_tensor_grad = _communicate(tensor_send_next=output_tensor,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
def send_backward_recv_forward(input_tensor_grad,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved
:type input_tensor_grad: :class:`torch.Tensor`
:type input_tensor_shape: :class:`torch.Size`
:return: The input tensor in forward step
:rtype: :class:`torch.Tensor`
"""
if gpc.is_pipeline_first_stage():
input_tensor = None
else:
input_tensor, _ = _communicate(tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
def send_forward_recv_forward(output_tensor,
input_tensor_shape,
recv_prev=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the
next member in pipeline, while recieves the input tensor from the
previous member in pipeline.
:param output_tensor: Tensor to be sent
:param input_tensor_shape: The shape of the tensor to be recieved
:type output_tensor: :class:`torch.Tensor`
:type input_tensor_shape: :class:`torch.Size`
:return: The input tensor in forward step
:rtype: :class:`torch.Tensor`
"""
input_tensor, _ = _communicate(tensor_send_next=output_tensor,
recv_prev=recv_prev,
recv_prev_shape=input_tensor_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor
def send_backward_recv_backward(input_tensor_grad,
output_grad_shape,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the grad tensor to the
previous member in pipeline, while recieves the grad tensor from the
next member in pipeline.
:param input_tensor_grad: Tensor to be sent
:param output_grad_shape: The shape of the tensor to be recieved
:type input_tensor_grad: :class:`torch.Tensor`
:type output_grad_shape: :class:`torch.Size`
:return: The grad of output tensor in forward step
:rtype: :class:`torch.Tensor`
"""
_, output_tensor_grad = _communicate(tensor_send_prev=input_tensor_grad,
recv_next=recv_next,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return output_tensor_grad
def send_forward_backward_recv_forward_backward(output_tensor,
input_tensor_grad,
input_tensor_shape,
output_grad_shape,
recv_prev=True,
recv_next=True,
prev_rank=None,
next_rank=None,
dtype=torch.float,
scatter_gather_tensors=False):
"""Batched communication operation. Sends the input tensor to the next and
the grad tensor to the previous, while recieves the grad tensor from the
next and the input tensor from the previous.
:param output_tensor: Tensor sent to the next
:param input_tensor_grad: Tensor sent to the previous
:param input_tensor_shape: The shape of the tensor recieved from the previous
:param output_grad_shape: The shape of the tensor recieved from the next
:type output_tensor: :class:`torch.Tensor`
:type input_tensor_grad: :class:`torch.Tensor`
:type input_tensor_shape: :class:`torch.Size`
:type output_grad_shape: :class:`torch.Size`
:return: (the input tensor in forward step, the grad of output tensor in forward step)
:rtype: (Tensor, Tensor)
"""
input_tensor, output_tensor_grad = _communicate(
tensor_send_next=output_tensor,
tensor_send_prev=input_tensor_grad,
recv_prev=recv_prev,
recv_next=recv_next,
recv_prev_shape=input_tensor_shape,
recv_next_shape=output_grad_shape,
prev_rank=prev_rank,
next_rank=next_rank,
dtype=dtype,
scatter_gather_tensors=scatter_gather_tensors)
return input_tensor, output_tensor_grad
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device, synchronize
def ring_forward(tensor_send_next: torch.Tensor, parallel_mode: ParallelMode):
"""Sends a tensor to the next member and recieves a tensor from the previous member.
This function returns the recieved tensor from the previous member.
:param tensor_send_next: Tensor sent to next member
:param parallel_mode: Parallel group mode used in this communication
:type tensor_send_next: :class:`torch.Tensor`
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: The tensor recieved from the previous
:rtype: :class:`torch.Tensor`
"""
buffer_shape = tensor_send_next.size()
ops = []
current_rank = gpc.get_global_rank()
tensor_recv_prev = torch.empty(buffer_shape,
requires_grad=True,
device=get_current_device(),
dtype=tensor_send_next.dtype)
# send to next rank
send_next_op = torch.distributed.P2POp(
torch.distributed.isend, tensor_send_next,
gpc.get_next_global_rank(parallel_mode))
ops.append(send_next_op)
# receive from prev rank
recv_prev_op = torch.distributed.P2POp(
torch.distributed.irecv, tensor_recv_prev,
gpc.get_prev_global_rank(parallel_mode))
ops.append(recv_prev_op)
if current_rank % 2 == 0:
ops = ops[::-1]
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
synchronize()
return tensor_recv_prev
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import get_current_device
def send_tensor_meta(tensor, need_meta=True, next_rank=None):
"""Sends tensor meta information before sending a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be sent before communications. This function
synchronizes with :func:`recv_tensor_meta`.
:param tensor: Tensor to be sent
:param need_meta: If False, meta information won't be sent
:param next_rank: The rank of the next member in pipeline parallel group
:type tensor: Tensor
:type need_meta: bool, optional
:type next_rank: int
:return: False
:rtype: bool
"""
if need_meta:
if next_rank is None:
next_rank = gpc.get_next_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
send_shape = torch.tensor(tensor.size(), **tensor_kwargs)
send_ndims = torch.tensor(len(tensor.size()), **tensor_kwargs)
dist.send(send_ndims, next_rank)
dist.send(send_shape, next_rank)
return False
def recv_tensor_meta(tensor_shape, prev_rank=None):
"""Recieves tensor meta information before recieving a specific tensor.
Since the recipient must know the shape of the tensor in p2p communications,
meta information of the tensor should be recieved before communications. This function
synchronizes with :func:`send_tensor_meta`.
:param tensor_shape: The shape of the tensor to be recieved
:param prev_rank: The rank of the source of the tensor
:type tensor_shape: torch.Size
:type prev_rank: int, optional
:return: The shape of the tensor to be recieved
:rtype: torch.Size
"""
if tensor_shape is None:
if prev_rank is None:
prev_rank = gpc.get_prev_global_rank(ParallelMode.PIPELINE)
tensor_kwargs = {'dtype': torch.long, 'device': get_current_device()}
recv_ndims = torch.empty((), **tensor_kwargs)
dist.recv(recv_ndims, prev_rank)
recv_shape = torch.empty(recv_ndims, **tensor_kwargs)
dist.recv(recv_shape, prev_rank)
tensor_shape = torch.Size(recv_shape)
return tensor_shape
def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
"""Break a tensor into equal 1D chunks.
:param tensor: Tensor to be splitted before communication
:param new_buffer: Whether uses a new buffer to store sliced tensor
:type tensor: torch.Tensor
:type new_buffer: bool, optional
:return splitted_tensor: The splitted tensor
:rtype splitted_tensor: torch.Tensor
"""
partition_size = torch.numel(tensor) // gpc.get_world_size(ParallelMode.PARALLEL_1D)
start_index = partition_size * gpc.get_local_rank(ParallelMode.PARALLEL_1D)
end_index = start_index + partition_size
if new_buffer:
data = torch.empty(partition_size, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
data.copy_(tensor.view(-1)[start_index:end_index])
else:
data = tensor.view(-1)[start_index:end_index]
return data
def gather_split_1d_tensor(tensor):
"""Opposite of above function, gather values from model parallel ranks.
:param tensor: Tensor to be gathered after communication
:type tensor: torch.Tensor
:return gathered: The gathered tensor
:rtype gathered: torch.Tensor
"""
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
numel = torch.numel(tensor)
numel_gathered = world_size * numel
gathered = torch.empty(numel_gathered, dtype=tensor.dtype,
device=torch.cuda.current_device(),
requires_grad=False)
chunks = [gathered[i*numel:(i+1)*numel] for i in range(world_size)]
dist.all_gather(chunks, tensor, group=gpc.get_group(ParallelMode.PARALLEL_1D))
return gathered
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
ALLOWED_MODES = [None, '1d', '2d', '2.5d', '3d', 'sequence']
TENSOR_PARALLEL_MODE = 'tensor_parallel_mode'
# intializer
INITIALIZER_MAPPING = {
'data': 'Initializer_Data',
'tensor': 'Initializer_Tensor',
'pipeline': 'Initializer_Pipeline',
'embedding': 'Initializer_Embedding',
'1d': 'Initializer_1D',
'2d': 'Initializer_2D',
'2.5d': 'Initializer_2p5D',
'3d': 'Initializer_3D',
'sequence': 'Initializer_Sequence',
'model': 'Initializer_Model',
'moe': 'Initializer_Moe'
}
# 3D parallelism groups
INPUT_GROUP_3D = 'input_group_3d'
WEIGHT_GROUP_3D = 'weight_group_3d'
OUTPUT_GROUP_3D = 'output_group_3d'
# Attributes of tensor parallel parameters
IS_TENSOR_PARALLEL = 'is_tensor_parallel'
NUM_PARTITIONS = 'num_partitions'
TENSOR_PARALLEL_ATTRIBUTES = [IS_TENSOR_PARALLEL, NUM_PARTITIONS]
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .parallel_mode import ParallelMode
from .process_group_initializer import *
from .random import *
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import inspect
import sys
from importlib.machinery import SourceFileLoader
from pathlib import Path
from colossalai.logging import get_dist_logger
class Config(dict):
"""This is a wrapper class for dict objects so that values of which can be
accessed as attributes.
:param config: The dict object to be wrapped
:type config: dict
"""
def __init__(self, config: dict = None):
if config is not None:
for k, v in config.items():
self._add_item(k, v)
def __missing__(self, key):
raise KeyError(key)
def __getattr__(self, key):
try:
value = super(Config, self).__getitem__(key)
return value
except KeyError:
raise AttributeError(key)
def __setattr__(self, key, value):
super(Config, self).__setitem__(key, value)
def _add_item(self, key, value):
if isinstance(value, dict):
self.__setattr__(key, Config(value))
else:
self.__setattr__(key, value)
def update(self, config):
assert isinstance(config, (Config, dict)), 'can only update dictionary or Config objects.'
for k, v in config.items():
self._add_item(k, v)
return self
@staticmethod
def from_file(filename: str):
"""Reads a python file and constructs a corresponding :class:`Config` object.
:param filename: Name of the file to construct the return object
:type filename: str
:raises AssertionError: Raises an AssertionError if the file does not exist, or the file
is not .py file
:return: A :class:`Config` object constructed with information in the file
:rtype: :class:`Config`
"""
# check config path
if isinstance(filename, str):
filepath = Path(filename).absolute()
elif isinstance(filename, Path):
filepath = filename.absolute()
assert filepath.exists(), f'{filename} is not found, please check your configuration path'
# check extension
extension = filepath.suffix
assert extension == '.py', 'only .py files are supported'
# import the config as module
remove_path = False
if filepath.parent not in sys.path:
sys.path.insert(0, (filepath))
remove_path = True
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module()
# load into config
config = Config()
for k, v in module.__dict__.items():
if k.startswith('__') or inspect.ismodule(v) or inspect.isclass(v):
continue
else:
config._add_item(k, v)
logger = get_dist_logger()
logger.debug('variables which starts with __, is a module or class declaration are omitted in config file')
# remove module
del sys.modules[module_name]
if remove_path:
sys.path.pop(0)
return config
class ConfigException(Exception):
pass
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import os
import random
from typing import Union
import numpy as np
import torch
import torch.distributed as dist
from colossalai.constants import ALLOWED_MODES, INITIALIZER_MAPPING
from colossalai.context.config import Config
from colossalai.global_variables import moe_env
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.logging import get_dist_logger
from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
class ParallelContext:
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
"""
__instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
self._world_sizes = dict()
self._groups = dict()
self._ranks_in_group = dict()
# load config from file
self._config = None
# default 3D parallel args, will be overwritten during process group intialization
self.world_size = 1
self.data_parallel_size = 1
self.pipeline_parallel_size = 1
self.tensor_parallel_size = 1
self.virtual_pipeline_parallel_size = None
self.virtual_pipeline_parallel_rank = None
# logging
self._verbose = False
self._logger = get_dist_logger()
@property
def config(self):
return self._config
@property
def verbose(self):
return self._verbose
@verbose.setter
def verbose(self, verbose_: bool):
self._verbose = verbose_
def load_config(self, config: Union[dict, str]):
"""Loads the configuration from either a dict or a file.
:param config: Either a dict containing the configuration information or the filename
of a file containing the configuration information
:type config: dict or str
:raises TypeError: Raises a TypeError if `config` is neither a dict or a str
"""
if isinstance(config, str):
self._config = Config.from_file(config)
elif isinstance(config, dict):
self._config = Config(config)
else:
raise TypeError("Invalid type for config, only dictionary or string is supported")
@staticmethod
def _check_parallel_mode(parallel_mode: ParallelMode):
assert isinstance(parallel_mode, ParallelMode)
def get_global_rank(self):
"""Returns the global rank of the current device.
:return: The global rank of the current device
:rtype: int
"""
return self._global_ranks[ParallelMode.GLOBAL]
def add_global_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the global rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._global_ranks[parallel_mode] = rank
def get_local_rank(self, parallel_mode: ParallelMode):
"""Returns the local rank of the current device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The local rank of the current device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
return self._local_ranks[parallel_mode]
def add_local_rank(self, parallel_mode: ParallelMode, rank: int):
"""Adds the local rank of the current device for `parallel_mode` to the context.
:param parallel_mode: The parallel mode for the rank
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param rank: The rank to be added
:type rank: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._local_ranks[parallel_mode] = rank
def get_next_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the next device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the next device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank + 1) % world_size]
def get_prev_global_rank(self, parallel_mode: ParallelMode):
"""Returns the global rank of the previous device.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The global rank of the previous device for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
# get rank and world size
local_rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
ranks_in_group = self.get_ranks_in_group(parallel_mode)
return ranks_in_group[(local_rank - 1) % world_size]
def is_first_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the first one
among its group for `parallel_mode`
:rtype: bool
"""
rank = self.get_local_rank(parallel_mode)
return rank == 0
def is_last_rank(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether the current device is the last one
among its group for `parallel_mode`
:rtype: bool
"""
rank = self.get_local_rank(parallel_mode)
world_size = self.get_world_size(parallel_mode)
return rank == world_size - 1
def is_pipeline_first_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != 0:
return False
return self.is_first_rank(ParallelMode.PIPELINE)
def is_pipeline_last_stage(self, ignore_virtual=False):
if not ignore_virtual:
if self.virtual_pipeline_parallel_size is not None and self.virtual_pipeline_parallel_rank != self.virtual_pipeline_parallel_size - 1:
return False
return self.is_last_rank(ParallelMode.PIPELINE)
def get_world_size(self, parallel_mode: ParallelMode):
"""Returns the world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The world size for `parallel_mode`
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
return self._world_sizes[parallel_mode]
def add_world_size(self, parallel_mode: ParallelMode, world_size: int):
"""Adds world size for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param world_size: The world size to be added
:type world_size: int
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._world_sizes[parallel_mode] = world_size
def get_group(self, parallel_mode: ParallelMode):
"""Returns the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: The group of the current device for `parallel_mode`
:rtype: torch.distributed.ProcessGroup
"""
self._check_parallel_mode(parallel_mode)
return self._groups[parallel_mode]
def add_group(self, parallel_mode: ParallelMode, group: dist.ProcessGroup):
"""Adds the group of the current device for `parallel_mode`.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param group: The group to be added
:type group: torch.distributed.ProcessGroup
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._groups[parallel_mode] = group
def get_ranks_in_group(self, parallel_mode: ParallelMode):
"""Returns the rank of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
:return: the rank of the current device for `parallel_mode` in the group
:rtype: int
"""
self._check_parallel_mode(parallel_mode)
return self._ranks_in_group[parallel_mode]
def add_ranks_in_group(self, parallel_mode: ParallelMode, ranks: list):
"""Adds the ranks of the current device for `parallel_mode` in the group.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:param ranks: List of ranks to be added
:type ranks: list
:raises AssertionError: Raises an AssertionError if `parallel_mode` is not an instance
of :class:`colossalai.context.ParallelMode`
"""
self._check_parallel_mode(parallel_mode)
self._ranks_in_group[parallel_mode] = ranks
def init_global_dist(self,
rank: int,
world_size: int,
backend: str,
host: str,
port: int
):
"""Initializes the global distributed environment
:param rank: rank for the default process group
:type rank: int
:param world_size: world size of the default process group
:type world_size: int
:param host: the master address for distributed training
:type host: str
:param port: the master port for distributed training
:type port: str
:param backend: backend for torch.distributed
:type backend: str
"""
# initialize the default process group
init_method = f'tcp://{host}:{port}'
dist.init_process_group(rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method)
# None will give the default global process group for pytorch dist operations
self._register_dist(rank, world_size, None,
list(range(world_size)), ParallelMode.GLOBAL)
self.add_global_rank(ParallelMode.GLOBAL, rank)
def _register_dist(self, local_rank, world_size,
process_group, ranks_in_group, mode):
self.add_local_rank(mode, local_rank)
self.add_world_size(mode, world_size)
self.add_group(mode, process_group)
self.add_ranks_in_group(mode, ranks_in_group)
def check_sanity(self):
"""Checks sanity of the parallel context.
:raises AssertionError: Raises an AssertionError if the world size does not equal to the product
of data paralle size, pipeline parallel size and tensor parallel size
"""
dps = self.data_parallel_size
pps = self.pipeline_parallel_size
tps = self.tensor_parallel_size
ws = self.world_size
assert ws == dps * pps * \
tps, f"Expected the world size {ws} to be equal to data parallel size ({dps}) * pipeline parallel size ({pps}) * tensor parallel size ({tps})"
def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str):
if key in config:
ele = config[key]
if isinstance(ele, int):
setattr(self, attr_name, ele)
elif isinstance(ele, dict):
setattr(self, attr_name, ele['size'])
else:
raise NotImplementedError(
f"Parallel configuration does not support this kind of argument, please use int or dict"
)
def init_parallel_groups(self):
"""Initializes the parallel groups.
:raises AssertionError: Raises an AssertionError if the field paralle is not present in the config file
"""
# get rank and world size
rank = self.get_global_rank()
world_size = self.get_world_size(ParallelMode.GLOBAL)
self.world_size = world_size
# set parallel size as attributes for global context
parallel_config = self.config.get('parallel', None)
if parallel_config is not None:
self._set_parallel_size_from_config(parallel_config, 'pipeline', 'pipeline_parallel_size')
self._set_parallel_size_from_config(parallel_config, 'tensor', 'tensor_parallel_size')
# the user should not set the data parallel size manually
# instead, it should be calculated based on other parallel config
self.data_parallel_size = self.world_size // (self.pipeline_parallel_size * self.tensor_parallel_size)
# get the tensor parallel mode and check
tensor_parallel_mode = None
if parallel_config is not None and 'tensor' in parallel_config and 'mode' in parallel_config['tensor']:
tensor_parallel_mode = parallel_config['tensor']['mode']
assert tensor_parallel_mode in ALLOWED_MODES, f"mode in the parallel config must be set to one of {ALLOWED_MODES}"
env.mode = tensor_parallel_mode
self.check_sanity()
pg_init = []
# LSG: init data parallel process group for compatibility with other parallel module such as zero
pg_init.append(dict(type=INITIALIZER_MAPPING['data']))
# LSG: init model parallel process group for compatibility with amp and clip grad
pg_init.append(dict(type=INITIALIZER_MAPPING['model']))
if self.pipeline_parallel_size > 1:
pg_init.append(dict(type=INITIALIZER_MAPPING['pipeline']))
pg_init.append(dict(type=INITIALIZER_MAPPING['tensor']))
# init specific tensor parallel group
if tensor_parallel_mode is not None:
tensor_parallel_cfg = parallel_config['tensor'].copy()
# remove duplicate parameters
tensor_parallel_cfg.pop('mode')
tensor_parallel_cfg.pop('size')
# add this config to initialize later
pg_init.append(dict(type=INITIALIZER_MAPPING[tensor_parallel_mode.lower()], **tensor_parallel_cfg))
# initialization for moe environment
if parallel_config is not None and 'moe' in parallel_config:
param = parallel_config['moe']
assert 'size' in param, "Moe model parallel size should be given"
moe_env.setup(param['size'])
pg_init.append(dict(type=INITIALIZER_MAPPING['moe']))
# run initialization of different process groups
for initializer_cfg in pg_init:
cfg = initializer_cfg.copy()
initializer_type = cfg.pop('type')
initializer = DIST_GROUP_INITIALIZER.get_module(initializer_type)(
rank, world_size, self.config,
self.data_parallel_size,
self.pipeline_parallel_size,
self.tensor_parallel_size,
**cfg)
parallel_setting = initializer.init_dist_group()
if isinstance(parallel_setting, list):
for args in parallel_setting:
self._register_dist(*args)
else:
self._register_dist(*parallel_setting)
def is_initialized(self, parallel_mode: ParallelMode):
"""Returns a boolean value indicating whether `parallel_mode` is initialized
in the current system.
:param parallel_mode: The chosen parallel mode
:type parallel_mode: :class:`colossalai.context.ParallelMode`
:return: a boolean value indicating whether `parallel_mode` is initialized
in the current system
:rtype: bool
"""
return parallel_mode in self._groups
def destroy(self):
"""Destroys the current distributed parallel environment.
"""
for mode, group in self._groups.items():
if mode is not ParallelMode.GLOBAL:
dist.destroy_process_group(group)
# destroy global process group
dist.destroy_process_group()
def set_device(self, device_ordinal: int = None):
"""Sets distributed processes to be bound to devices.
:param device_ordinal: the device id to be bound to
:type device_ordinal: int, optional
"""
global_rank = self.get_global_rank()
if device_ordinal is None:
devices_per_node = torch.cuda.device_count()
device_ordinal = global_rank % devices_per_node
torch.cuda.set_device(device_ordinal)
if self._verbose:
self._logger.info(f'process rank {global_rank} is bound to device {device_ordinal}')
def set_seed(self, seed: int):
"""Sets seeds for all random libraries.
:param seed: seed for random states
:type seed: int
"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
global_rank = self.get_global_rank()
if torch.cuda.is_available():
# create random seed for different parallel modes
# data parallel seed are kept the same
parallel_seed = seed
add_seed(ParallelMode.DATA, parallel_seed)
# model parallel seeds are different across ranks
pipeline_offset = self._local_ranks.get(ParallelMode.PIPELINE, 0)
# add seed for data parallel and tensor parallel only
if self.is_initialized(ParallelMode.TENSOR):
tp_rank = self.get_local_rank(ParallelMode.TENSOR)
# 100 is only to increase the diff in seeds between pipeline stages
tp_rank_with_offset = tp_rank + pipeline_offset * 1024
tp_seed = seed + tp_rank_with_offset
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
seeds = get_seeds()
seed_str = ', '.join([f'{k}: {v}' for k, v in seeds.items()])
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, {seed_str},"
f"the default parallel seed is {ParallelMode.DATA}.")
else:
if self._verbose:
self._logger.info(
f"initialized seed on rank {global_rank}, "
f"numpy: {seed}, python random: {seed}, pytorch: {seed}",
ranks=[0])
self._logger.info(
'WARNING: CUDA is not available, thus CUDA RNG cannot be used to track CUDA random number states',
ranks=[0])
def set_virtual_pipeline_parallel_size(self, size):
self.virtual_pipeline_parallel_size = size
def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from enum import Enum
# parallel modes
class ParallelMode(Enum):
"""This is an enumeration class containing all possible parallel modes.
"""
GLOBAL = 'global'
# common parallel
DATA = 'data'
# model parallel - containing tensor and pipeline parallel groups
# this is added to facilitate amp and grad clipping in hybrid parallel
MODEL = 'model'
# pipeline parallel
PIPELINE = 'pipe'
# containing all ranks in tensor parallel
TENSOR = 'tensor'
# sequence parallel
SEQUENCE = 'sequence'
SEQUENCE_DP = 'sequence_dp'
# 1D Parallel
PARALLEL_1D = '1d'
# 2D parallel
PARALLEL_2D_ROW = '2d_row'
PARALLEL_2D_COL = '2d_col'
# 3D parallel
PARALLEL_3D_INPUT = '3d_input'
PARALLEL_3D_WEIGHT = '3d_weight'
PARALLEL_3D_OUTPUT = '3d_output'
# 2.5D parallel
PARALLEL_2P5D_ROW = '2p5d_row'
PARALLEL_2P5D_COL = '2p5d_col'
PARALLEL_2P5D_DEP = '2p5d_dep'
PARALLEL_2P5D_XZ = '2p5d_xz'
# MOE parallel
MOE_DATA = 'moe_data'
MOE_MODEL = 'moe_model'
from .initializer_1d import Initializer_1D
from .initializer_2d import Initializer_2D
from .initializer_2p5d import Initializer_2p5D
from .initializer_3d import Initializer_3D
from .initializer_data import Initializer_Data
from .initializer_pipeline import Initializer_Pipeline
from .initializer_sequence import Initializer_Sequence
from .initializer_tensor import Initializer_Tensor
from .initializer_model import Initializer_Model
from .initializer_moe import Initializer_Moe
from .process_group_initializer import ProcessGroupInitializer
__all__ = [
'Initializer_Tensor', 'Initializer_Sequence', 'Initializer_Pipeline',
'Initializer_Data', 'Initializer_2p5D', 'Initializer_2D', 'Initializer_3D',
'Initializer_1D', 'ProcessGroupInitializer', 'Initializer_Model',
'Initializer_Moe'
]
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment