"docs/vscode:/vscode.git/clone" did not exist on "683c3fe996c5cf2c22abe903b1361da3d8ad2fc3"
Commit aa140059 authored by Shenggan's avatar Shenggan
Browse files

use colossalai as distributed core

parent 3b0299c2
@@ -48,12 +48,11 @@ from fastfold.model import Evoformer
evoformer_layer = Evoformer()
```
If you want to use Dynamic Axial Parallelism, add a line to initialize with `fastfold.distributed.init_dap` after `torch.distributed.init_process_group`.
If you want to use Dynamic Axial Parallelism, add a line to initialize with `fastfold.distributed.init_dap`.
```python
from fastfold.distributed import init_dap
torch.distributed.init_process_group(backend='nccl', init_method='env://')
init_dap(args.dap_size)
```
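A minimal end-to-end sketch of the new launch flow, added here for illustration only (it is not part of the diff). It assumes a torchrun-style environment where `RANK`, `WORLD_SIZE`, and the master address are already set; the DAP size of 2 is an example value.
```python
# Illustrative sketch of the post-commit launch flow, not code from this commit.
# The new init_dap initializes the process group itself via ColossalAI, so
# torch.distributed.init_process_group must NOT be called first; per this diff,
# init_dap exits with an error if the default group is already initialized.
# For a single-process run, init_dap fills in RANK/WORLD_SIZE defaults itself.
from fastfold.distributed import init_dap
from fastfold.model import Evoformer

init_dap(2)  # DAP (tensor-parallel) size; 2 is an illustrative choice
evoformer_layer = Evoformer()
```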
......
from .core import (init_dap, dap_is_initialized, get_tensor_model_parallel_group,
get_data_parallel_group, get_tensor_model_parallel_world_size,
get_tensor_model_parallel_rank, get_data_parallel_world_size,
get_data_parallel_rank, get_tensor_model_parallel_src_rank)
from .core import init_dap
from .comm import (_reduce, _split, _gather, copy, scatter, reduce, gather, col_to_row, row_to_col)
__all__ = [
'init_dap', 'dap_is_initialized', 'get_tensor_model_parallel_group',
'get_data_parallel_group', 'get_tensor_model_parallel_world_size',
'get_tensor_model_parallel_rank', 'get_data_parallel_world_size', 'get_data_parallel_rank',
'get_tensor_model_parallel_src_rank', '_reduce', '_split', '_gather', 'copy', 'scatter',
'reduce', 'gather', 'col_to_row', 'row_to_col'
'init_dap', '_reduce', '_split', '_gather', 'copy', 'scatter', 'reduce', 'gather', 'col_to_row',
'row_to_col'
]
\ No newline at end of file
@@ -4,8 +4,9 @@ import torch
import torch.distributed as dist
from torch import Tensor
from .core import (get_tensor_model_parallel_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from .core import ensure_divisibility
@@ -15,42 +16,50 @@ def divide(numerator, denominator):
def _reduce(tensor: Tensor) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
dist.all_reduce(tensor,
op=dist.ReduceOp.SUM,
group=get_tensor_model_parallel_group(),
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
return tensor
def _split(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
split_size = divide(tensor.shape[dim], get_tensor_model_parallel_world_size())
split_size = divide(tensor.shape[dim], gpc.get_world_size(ParallelMode.TENSOR))
tensor_list = torch.split(tensor, split_size, dim=dim)
output = tensor_list[get_tensor_model_parallel_rank()].contiguous()
output = tensor_list[gpc.get_local_rank(ParallelMode.TENSOR)].contiguous()
return output
def _gather(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
if dim == 1:
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
dist.all_gather(list(tensor_list), tensor, group=get_tensor_model_parallel_group(), async_op=False)
tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
dist.all_gather(list(tensor_list),
tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
else:
tensor_list = [torch.empty_like(tensor) for _ in range(get_tensor_model_parallel_world_size())]
dist.all_gather(tensor_list, tensor, group=get_tensor_model_parallel_group(), async_op=False)
tensor_list = [
torch.empty_like(tensor) for _ in range(gpc.get_world_size(ParallelMode.TENSOR))
]
dist.all_gather(tensor_list,
tensor,
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
output = torch.cat(tensor_list, dim=dim)
return output
@@ -135,28 +144,28 @@ class Gather(torch.autograd.Function):
def _all_to_all(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
split_size = divide(tensor.shape[in_dim], gpc.get_world_size(ParallelMode.TENSOR))
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
if out_dim == 1:
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
output_tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
else:
output_tensor_list = [torch.ones_like(tensor_) for tensor_ in input_tensor_list]
dist.all_to_all(output_tensor_list,
input_tensor_list,
group=get_tensor_model_parallel_group(),
group=gpc.get_group(ParallelMode.TENSOR),
async_op=False)
output = torch.cat(output_tensor_list, dim=out_dim)
......
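For readers unfamiliar with the collective being ported in `comm.py`, the following standalone sketch (an illustration, not code from this commit) shows the dimension-wise all-gather that `_gather` performs, expressed with plain `torch.distributed` calls; the `group` argument stands in for `gpc.get_group(ParallelMode.TENSOR)`.
```python
# Standalone sketch of the dim-wise all-gather pattern behind _gather.
# world-size/group lookups stand in for gpc.get_world_size(ParallelMode.TENSOR)
# and gpc.get_group(ParallelMode.TENSOR) used in the actual code.
import torch
import torch.distributed as dist
from torch import Tensor

def gather_along_dim(tensor: Tensor, dim: int, group=None) -> Tensor:
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return tensor
    # Each rank contributes an equally sized shard; all-gather, then concatenate.
    tensor_list = [torch.empty_like(tensor) for _ in range(world_size)]
    dist.all_gather(tensor_list, tensor.contiguous(), group=group)
    return torch.cat(tensor_list, dim=dim)
```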
@@ -5,21 +5,23 @@ import torch
import torch.distributed as dist
from torch import Tensor
from .core import get_tensor_model_parallel_world_size, get_tensor_model_parallel_group
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from .comm import _split, divide
def _gather_async(tensor: Tensor, dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor, None
output_shape = list(tensor.shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
work = dist.all_gather(list(tensor_list),
tensor,
group=get_tensor_model_parallel_group(),
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return output, work
@@ -45,13 +47,13 @@ class GatherAsyncOpp(torch.autograd.Function):
@staticmethod
def forward(ctx: "GatherAsyncOpp", input: Tensor) -> Tensor:
mp_size = get_tensor_model_parallel_world_size()
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
output = rearrange(input, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
@staticmethod
def backward(ctx: "GatherAsyncOpp", grad_output: Tensor) -> Tuple[Tensor]:
mp_size = get_tensor_model_parallel_world_size()
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
n, h, w, c = grad_output.shape
return grad_output.resize_(n, h * mp_size, int(w / mp_size), c)
@@ -66,28 +68,28 @@ class GatherAsync(torch.autograd.Function):
@staticmethod
def backward(ctx: "GatherAsync", grad_output: Tensor, grad_work=None) -> Tuple[Tensor]:
if ctx.dim == 2:
mp_size = get_tensor_model_parallel_world_size()
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
n, h, w, c = grad_output.shape
grad_output.resize_(n, int(h / mp_size), w * mp_size, c)
return _split(grad_output, dim=ctx.dim), None
def _all_to_all_async(tensor: Tensor, in_dim: int = -1, out_dim: int = -1) -> Tensor:
if get_tensor_model_parallel_world_size() == 1:
if gpc.get_world_size(ParallelMode.TENSOR) == 1:
return tensor, None
split_size = divide(tensor.shape[in_dim], get_tensor_model_parallel_world_size())
split_size = divide(tensor.shape[in_dim], gpc.get_world_size(ParallelMode.TENSOR))
input_tensor_list = torch.split(tensor, split_size, dim=in_dim)
input_tensor_list = [tensor_.contiguous() for tensor_ in input_tensor_list]
output_shape = list(input_tensor_list[0].shape)
output_shape[1] *= get_tensor_model_parallel_world_size()
output_shape[1] *= gpc.get_world_size(ParallelMode.TENSOR)
output = torch.empty(output_shape, dtype=tensor.dtype, device=tensor.device)
output_tensor_list = output.chunk(get_tensor_model_parallel_world_size(), dim=1)
output_tensor_list = output.chunk(gpc.get_world_size(ParallelMode.TENSOR), dim=1)
work = dist.all_to_all(list(output_tensor_list),
input_tensor_list,
group=get_tensor_model_parallel_group(),
group=gpc.get_group(ParallelMode.TENSOR),
async_op=True)
return output, work
@@ -114,7 +116,7 @@ class All_to_All_Async(torch.autograd.Function):
WORLD_WORK_ALL2ALL.wait()
WORLD_WORK_ALL2ALL = None
if ctx.in_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
grad_output = rearrange(grad_output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return grad_output, None, None
@@ -132,7 +134,7 @@ class All_to_All_Async_Opp(torch.autograd.Function):
if work:
work.wait()
if out_dim == 2:
mp_size = get_tensor_model_parallel_world_size()
mp_size = gpc.get_world_size(ParallelMode.TENSOR)
output = rearrange(output, 'n (x h) w c -> n h (x w) c', x=mp_size)
return output
......
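The functions in `comm_async.py` all follow the same overlap pattern: launch the collective with `async_op=True`, do other work, then wait on the returned handle before touching the output buffer. The sketch below illustrates that pattern with plain `torch.distributed`; it is not code from this commit, and `some_other_compute` and `tp_group` are hypothetical placeholders.
```python
# Illustration of the async-collective overlap pattern used in comm_async.py.
import torch
import torch.distributed as dist

def gather_async(tensor: torch.Tensor, group=None):
    world_size = dist.get_world_size(group=group)
    if world_size == 1:
        return tensor, None
    out = [torch.empty_like(tensor) for _ in range(world_size)]
    # async_op=True returns a work handle instead of blocking.
    work = dist.all_gather(out, tensor.contiguous(), group=group, async_op=True)
    return out, work

# Usage sketch: overlap computation with communication, then synchronize.
# out, work = gather_async(x, group=tp_group)       # tp_group: hypothetical group
# y = some_other_compute(x)                         # hypothetical overlapping work
# if work is not None:
#     work.wait()                                   # ensure the gather finished
# gathered = torch.cat(out, dim=-1)
```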
import torch.distributed as dist
import os
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
# Intra-layer model parallel group that the current rank belongs to.
_TENSOR_MODEL_PARALLEL_GROUP = None
# These values enable us to change the mpu sizes on the fly.
_TENSOR_MODEL_PARALLEL_WORLD_SIZE = None
_TENSOR_MODEL_PARALLEL_RANK = None
import torch
import colossalai
def ensure_divisibility(numerator, denominator):
"""Ensure that numerator is divisible by the denominator."""
assert numerator % denominator == 0, '{} is not divisible by {}'.format(numerator, denominator)
def init_dap(tensor_model_parallel_size_=1):
assert dist.is_initialized()
world_size = dist.get_world_size()
rank = dist.get_rank()
# check dist config
ensure_divisibility(world_size, tensor_model_parallel_size_)
data_parallel_size_ = world_size // tensor_model_parallel_size_
# Build the data-parallel groups.
global _DATA_PARALLEL_GROUP
assert _DATA_PARALLEL_GROUP is None, \
'data parallel group is already initialized'
for i in range(tensor_model_parallel_size_):
ranks = range(i, world_size, tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_DATA_PARALLEL_GROUP = group
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, \
'tensor model parallel group is already initialized'
# Build the model-parallel groups.
for i in range(data_parallel_size_):
ranks = range(i * tensor_model_parallel_size_, (i + 1) * tensor_model_parallel_size_)
group = dist.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
if dist.get_rank() == 0:
print('> initialize tensor model parallel with size {}'.format(tensor_model_parallel_size_))
print('> initialize data parallel with size {}'.format(data_parallel_size_))
def dap_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or \
_DATA_PARALLEL_GROUP is None:
return False
return True
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, \
'intra_layer_model parallel group is not initialized'
return _TENSOR_MODEL_PARALLEL_GROUP
def get_data_parallel_group():
"""Get the data parallel group the caller rank belongs to."""
assert _DATA_PARALLEL_GROUP is not None, \
'data parallel group is not initialized'
return _DATA_PARALLEL_GROUP
def get_tensor_model_parallel_world_size():
if not dap_is_initialized():
return 1
"""Return world size for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_WORLD_SIZE
if _TENSOR_MODEL_PARALLEL_WORLD_SIZE is not None:
return _TENSOR_MODEL_PARALLEL_WORLD_SIZE
return dist.get_world_size(group=get_tensor_model_parallel_group())
def get_tensor_model_parallel_rank():
if not dap_is_initialized():
return 0
"""Return my rank for the tensor model parallel group."""
global _TENSOR_MODEL_PARALLEL_RANK
if _TENSOR_MODEL_PARALLEL_RANK is not None:
return _TENSOR_MODEL_PARALLEL_RANK
return dist.get_rank(group=get_tensor_model_parallel_group())
def set_distributed_environ(key, value):
os.environ[str(key)] = str(value)
def get_data_parallel_world_size():
"""Return world size for the data parallel group."""
return dist.get_world_size(group=get_data_parallel_group())
def init_dap(tensor_model_parallel_size_=None):
colossalai.logging.disable_existing_loggers()
if tensor_model_parallel_size_ == None:
if 'WORLD_SIZE' in os.environ:
tensor_model_parallel_size_ = int(os.environ['WORLD_SIZE'])
else:
tensor_model_parallel_size_ = 1
def get_data_parallel_rank():
"""Return my rank for the data parallel group."""
return dist.get_rank(group=get_data_parallel_group())
if torch.distributed.is_initialized():
_logger = colossalai.logging.get_dist_logger()
_logger.error(
"use fastfold.distributed.init_dap instead of torch.distributed.init_process_group!")
exit(-1)
# set distributed environ for single device launch
if 'RANK' not in os.environ:
set_distributed_environ('WORLD_SIZE', 1)
set_distributed_environ('RANK', 0)
set_distributed_environ('LOCAL_RANK', 0)
set_distributed_environ('MASTER_ADDR', "localhost")
set_distributed_environ('MASTER_PORT', 10045)
def get_tensor_model_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank
in the tensor model parallel group."""
global_rank = dist.get_rank()
local_world_size = get_tensor_model_parallel_world_size()
return (global_rank // local_world_size) * local_world_size
colossalai.launch_from_torch(
config={"parallel": dict(tensor=dict(size=tensor_model_parallel_size_))})
@@ -3,10 +3,12 @@ from typing import Tuple, Optional
import torch
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from fastfold.model import MSAStack, OutProductMean, PairStack
from fastfold.distributed.comm_async import All_to_All_Async, All_to_All_Async_Opp
from fastfold.distributed.comm import gather, scatter
from fastfold.distributed import get_tensor_model_parallel_world_size
class EvoformerBlock(nn.Module):
@@ -31,7 +33,7 @@ class EvoformerBlock(nn.Module):
_mask_trans: bool = True,
) -> Tuple[torch.Tensor, torch.Tensor]:
dap_size = get_tensor_model_parallel_world_size()
dap_size = gpc.get_world_size(ParallelMode.TENSOR)
seq_length = pair_mask.size(-1)
padding_size = (int(seq_length / dap_size) + 1) * dap_size - seq_length
......
@@ -41,9 +41,7 @@ def main(args):
if local_rank != -1:
distributed_inference_ = True
torch.cuda.set_device(local_rank)
torch.distributed.init_process_group(backend='nccl', init_method='env://')
fastfold.distributed.init_dap(torch.distributed.get_world_size())
fastfold.distributed.init_dap()
else:
distributed_inference_ = False
......