Commit 42254422 authored by Hongxin Liu
Browse files

[pipeline] add stage manager (#4093)

* [pipeline] add stage manager

* [test] add pipeline stage manager test

* [pipeline] add docstring for stage manager
parent 5e1a9d48
from contextlib import contextmanager
from typing import Dict, Iterator, List, Optional, Tuple

import torch.distributed as dist
from torch.distributed import ProcessGroup

from colossalai.cluster import ProcessGroupMesh
class PipelineStageManager:
    """PipelineStageManager is a helper class to manage pipeline stages.

    On construction it records the ranks of the neighbouring pipeline stages
    and builds one process group per pair of adjacent stages along
    ``pipeline_axis``.

    Args:
        pg_mesh (ProcessGroupMesh): Process group mesh.
        pipeline_axis (int): The axis along which the pipeline is constructed.

    Attributes:
        num_stages (int): Number of stages in the pipeline.
        stage (int): The current stage.
        num_virtual_stages (Optional[int]): Number of virtual stages in the
            pipeline; ``None`` until ``set_num_virtual_stages`` is called.
        virtual_stage (Optional[int]): The current virtual stage; ``None``
            until ``set_virtual_stage`` is called.
        prev_rank (Optional[int]): Rank of the previous stage, ``None`` on the
            first stage.
        next_rank (Optional[int]): Rank of the next stage, ``None`` on the
            last stage.
        p2p_groups (Dict[Tuple[int, int], ProcessGroup]): P2P groups this
            process belongs to, keyed by the ranks reported by
            ``pg_mesh.get_ranks_in_group``.
    """

    def __init__(self, pg_mesh: "ProcessGroupMesh", pipeline_axis: int) -> None:
        self.pg_mesh = pg_mesh
        self.pipeline_axis = pipeline_axis
        self.num_virtual_stages: Optional[int] = None
        self.virtual_stage: Optional[int] = None
        # These hold whatever ``pg_mesh.ravel`` returns; the getters below
        # expose them as plain ``int`` ranks.
        self.prev_rank: Optional[int] = None
        self.next_rank: Optional[int] = None
        self.p2p_groups: Dict[Tuple[int, int], ProcessGroup] = {}

        num_stages = self.num_stages
        stage = self.stage

        # init prev and next rank
        coord = self.pg_mesh.coordinate()
        if stage > 0:
            self.prev_rank = self._neighbor_rank(coord, -1)
        if stage < num_stages - 1:
            self.next_rank = self._neighbor_rank(coord, 1)

        # init p2p process groups
        # ``get_group_along_axis`` is called for every adjacent pair, even the
        # pairs this stage is not part of, exactly as before; only the groups
        # containing this stage are stored.
        for prev in range(num_stages - 1):
            cur = prev + 1
            group = self.pg_mesh.get_group_along_axis(self.pipeline_axis, [prev, cur])
            if stage in (prev, cur):
                ranks_in_group = self.pg_mesh.get_ranks_in_group(group)
                self.p2p_groups[tuple(ranks_in_group)] = group

    def _neighbor_rank(self, coord: Tuple[int, ...], offset: int) -> int:
        """Ravel the coordinate ``offset`` steps away from ``coord`` along the pipeline axis."""
        axis = self.pipeline_axis
        neighbor_coord = coord[:axis] + (coord[axis] + offset,) + coord[axis + 1:]
        return self.pg_mesh.ravel(neighbor_coord, self.pg_mesh.shape)

    def _require_num_virtual_stages(self) -> None:
        """Raise ``AssertionError`` if the number of virtual stages was never set."""
        # Raised explicitly instead of via ``assert`` so the check is not
        # stripped when Python runs with ``-O``.
        if self.num_virtual_stages is None:
            raise AssertionError("Number of virtual stages is not set.")

    def is_first_stage(self, virtual: bool = False) -> bool:
        """Is the current stage the first stage.

        Args:
            virtual (bool, optional): Whether to consider virtual stages. Defaults to False.

        Returns:
            bool: Whether the current stage is the first stage.

        Raises:
            AssertionError: If ``virtual`` is True and the number of virtual
                stages has not been set.
        """
        if virtual:
            self._require_num_virtual_stages()
            return self.virtual_stage == 0
        return self.stage == 0

    def is_last_stage(self, virtual: bool = False) -> bool:
        """Is the current stage the last stage.

        Args:
            virtual (bool, optional): Whether to consider virtual stages. Defaults to False.

        Returns:
            bool: Whether the current stage is the last stage.

        Raises:
            AssertionError: If ``virtual`` is True and the number of virtual
                stages has not been set.
        """
        if virtual:
            self._require_num_virtual_stages()
            return self.virtual_stage == self.num_virtual_stages - 1
        return self.stage == self.num_stages - 1

    @property
    def num_stages(self) -> int:
        """Number of stages in the pipeline.

        Returns:
            int: Number of stages in the pipeline.
        """
        return self.pg_mesh.size(self.pipeline_axis)

    @property
    def stage(self) -> int:
        """Current stage.

        Returns:
            int: Current stage.
        """
        return self.pg_mesh.coordinate(self.pipeline_axis)

    def get_rank(self) -> int:
        """Get the rank of the current process.

        Returns:
            int: Rank of the current process.
        """
        return dist.get_rank()

    def get_prev_rank(self) -> int:
        """Get the rank of the previous stage.

        Returns:
            int: Rank of the previous stage.

        Raises:
            AssertionError: If the current stage is the first stage.
        """
        # Explicit raise: a bare ``assert`` vanishes under ``-O`` and the
        # caller would silently receive ``None``.
        if self.is_first_stage():
            raise AssertionError("Cannot get previous rank in the first stage.")
        return self.prev_rank

    def get_next_rank(self) -> int:
        """Get the rank of the next stage.

        Returns:
            int: Rank of the next stage.

        Raises:
            AssertionError: If the current stage is the last stage.
        """
        if self.is_last_stage():
            raise AssertionError("Cannot get next rank in the last stage.")
        return self.next_rank

    def set_num_virtual_stages(self, num_virtual_stages: int) -> None:
        """Set the number of virtual stages.

        Args:
            num_virtual_stages (int): Number of virtual stages.
        """
        self.num_virtual_stages = num_virtual_stages

    def set_virtual_stage(self, virtual_stage: int) -> None:
        """Set the virtual stage.

        Args:
            virtual_stage (int): Virtual stage.
        """
        self.virtual_stage = virtual_stage

    @contextmanager
    def switch_virtual_stage(self, virtual_stage: int) -> Iterator[None]:
        """A context manager to switch virtual stage.

        The previous virtual stage is restored on exit, including when the
        body raises.

        Args:
            virtual_stage (int): Target virtual stage.
        """
        old_stage = self.virtual_stage
        try:
            self.set_virtual_stage(virtual_stage)
            yield
        finally:
            self.set_virtual_stage(old_stage)

    def get_p2p_process_group(self, first_rank: int, second_rank: int) -> ProcessGroup:
        """Get the p2p process group between two ranks. The order of the two ranks does not matter.

        Args:
            first_rank (int): The first rank.
            second_rank (int): The second rank.

        Returns:
            ProcessGroup: P2P process group between the two ranks.

        Raises:
            KeyError: If this process holds no P2P group for the given pair.
        """
        # Lookup is by (smaller, larger) rank, so either argument order works.
        if first_rank > second_rank:
            first_rank, second_rank = second_rank, first_rank
        try:
            return self.p2p_groups[(first_rank, second_rank)]
        except KeyError:
            raise KeyError(
                f"No P2P process group between ranks {first_rank} and {second_rank} "
                "is held by this process."
            ) from None

    def init_process_group_by_stages(self, stages: List[int]) -> ProcessGroup:
        """Get the process group of the given stages.

        Args:
            stages (List[int]): List of stages.

        Returns:
            ProcessGroup: Process group of the given stages.
        """
        return self.pg_mesh.get_group_along_axis(self.pipeline_axis, stages)
import pytest
import torch.distributed as dist
import colossalai
from colossalai.cluster import ProcessGroupMesh
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.testing import spawn
def check_stage_manager():
    """Verify ``PipelineStageManager`` against hand-written expectations.

    Uses a 2 x 2 mesh first and a flat mesh of 4 afterwards. The expectation
    tables cover global ranks 0-3, so this is meant to be called from each
    process of a 4-process distributed job.
    """
    # Mesh axes are ordered (data parallel, pipeline parallel).
    pp_axis = 1
    dp_size, pp_size = 2, 2
    # Expected mesh coordinate of every global rank.
    coord_of_rank = {
        0: (0, 0),
        1: (0, 1),
        2: (1, 0),
        3: (1, 1),
    }
    # Expected global ranks forming the pipeline each rank belongs to.
    pipeline_of_rank = {
        0: [0, 1],
        1: [0, 1],
        2: [2, 3],
        3: [2, 3],
    }

    mesh = ProcessGroupMesh(dp_size, pp_size)
    manager = PipelineStageManager(mesh, pp_axis)
    rank = dist.get_rank()

    # stage bookkeeping
    assert manager.num_stages == pp_size
    assert manager.stage == coord_of_rank[rank][pp_axis]

    # first / last stage detection
    peers = pipeline_of_rank[rank]
    position = peers.index(rank)
    expect_first = position == 0
    assert manager.is_first_stage() == expect_first
    expect_last = position == len(peers) - 1
    assert manager.is_last_stage() == expect_last

    # neighbouring ranks
    if not expect_first:
        assert manager.get_prev_rank() == peers[position - 1]
    if not expect_last:
        assert manager.get_next_rank() == peers[position + 1]

    # virtual stages: two per physical stage
    total_virtual = PP_TOTAL = pp_size * 2
    manager.set_num_virtual_stages(total_virtual)
    assert manager.num_virtual_stages == PP_TOTAL
    manager.set_virtual_stage(manager.stage * 2)
    assert manager.virtual_stage == manager.stage * 2
    with manager.switch_virtual_stage(manager.stage * 2 + 1):
        assert manager.virtual_stage == manager.stage * 2 + 1
    # leaving the context restores the previous virtual stage
    assert manager.virtual_stage == manager.stage * 2

    # p2p groups: a barrier on each group this rank belongs to
    for left, right in zip(peers[:-1], peers[1:]):
        if rank in (left, right):
            p2p_group = manager.get_p2p_process_group(left, right)
            dist.barrier(group=p2p_group)

    # stage groups on a flat mesh of 4: only stages 0 and 2 join the barrier
    mesh = ProcessGroupMesh(4)
    manager = PipelineStageManager(mesh, 0)
    stage_group = manager.init_process_group_by_stages([0, 2])
    if rank in (0, 2):
        dist.barrier(group=stage_group)
def run_dist(rank, world_size, port):
    """Per-process entry point: launch ColossalAI on localhost, then run the checks.

    Args:
        rank: Rank forwarded to ``colossalai.launch``.
        world_size: World size forwarded to ``colossalai.launch``.
        port: Port forwarded to ``colossalai.launch``.
    """
    launch_options = dict(config={}, rank=rank, world_size=world_size, port=port, host='localhost')
    colossalai.launch(**launch_options)
    check_stage_manager()
@pytest.mark.dist
def test_process_group_mesh():
    """Run the stage-manager checks through ``spawn`` with 4 as its second argument."""
    # The expectation tables in ``check_stage_manager`` cover ranks 0-3.
    num_processes = 4
    spawn(run_dist, num_processes)
# Allow running the test directly as a script, without going through pytest.
if "__main__" == __name__:
    test_process_group_mesh()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment