Unverified commit f10b4b89 authored by Masaki Kozuki, committed by GitHub

[transformer] `parallel_state`: Position Embedding (#1343)

* update

* Add comment to `destroy_model_parallel`
parent 28f8539c
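
The diff below adds a position-embedding process group alongside the existing embedding group. As a rough orientation, here is a minimal sketch (not part of the commit) of how the new helpers might be exercised once `torch.distributed` is initialized; the parallel sizes and backend are placeholders.

```python
import torch
from apex.transformer import parallel_state

# Assumes the processes are launched with torchrun or similar (placeholder backend).
torch.distributed.init_process_group(backend="nccl")

parallel_state.initialize_model_parallel(
    tensor_model_parallel_size_=1,
    pipeline_model_parallel_size_=4,
    pipeline_model_parallel_split_rank_=2,  # encoder/decoder split, e.g. for T5-style models
)

# New in this commit: query membership in, and the handle of, the position embedding group.
if parallel_state.is_rank_in_position_embedding_group():
    pos_emb_group = parallel_state.get_position_embedding_group()
    print(f"rank {torch.distributed.get_rank()} holds a copy of the position embedding")

parallel_state.destroy_model_parallel()
torch.distributed.destroy_process_group()
```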
@@ -12,12 +12,22 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# TODO (mkozuki): Replace assert with RuntimeError.
# TODO (mkozuki): Sort the functions in the same order of megatron/mpu/initialize.py
"""Model and data parallel groups."""
from typing import Tuple, Optional
import torch
from apex.transformer.utils import ensure_divisibility
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
# set(megatron_mpu_initialize_funcs) - set(apex.transformer.parallel_state) =
# {
# 'get_num_layers',
# }
# Intra-layer model parallel group that the current rank belongs to.
@@ -28,6 +38,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
@@ -44,6 +56,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None
# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None
# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage
_PIPELINE_GLOBAL_RANKS = None
@@ -55,10 +70,10 @@ def is_unitialized():
def initialize_model_parallel(
tensor_model_parallel_size_: int = 1,
pipeline_model_parallel_size_: int = 1,
virtual_pipeline_model_parallel_size_: Optional[int] = None,
pipeline_model_parallel_split_rank_: Optional[int] = None,
) -> None:
"""
Initialize model data parallel groups.
@@ -87,28 +102,46 @@ def initialize_model_parallel(
"""
# Get world size and rank. Ensure some consistencies.
assert torch.distributed.is_initialized()
world_size = torch.distributed.get_world_size()
tensor_model_parallel_size = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size = min(pipeline_model_parallel_size_, world_size)
ensure_divisibility(world_size, tensor_model_parallel_size * pipeline_model_parallel_size)
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)
world_size: int = torch.distributed.get_world_size()
tensor_model_parallel_size: int = min(tensor_model_parallel_size_, world_size)
pipeline_model_parallel_size: int = min(pipeline_model_parallel_size_, world_size)
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
raise RuntimeError(
f"`world_size` ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
)
data_parallel_size: int = world_size // (
tensor_model_parallel_size * pipeline_model_parallel_size
)
if torch.distributed.get_rank() == 0:
print("> initializing tensor model parallel with size {}".format(tensor_model_parallel_size))
print("> initializing pipeline model parallel with size {}".format(pipeline_model_parallel_size))
print("> initializing data parallel with size {}".format(data_parallel_size))
_logger.info(
"> initializing tensor model parallel with size {}".format(
tensor_model_parallel_size
)
)
_logger.info(
"> initializing pipeline model parallel with size {}".format(
pipeline_model_parallel_size
)
)
_logger.info(
"> initializing data parallel with size {}".format(data_parallel_size)
)
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size
num_data_parallel_groups = world_size // data_parallel_size
num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
num_data_parallel_groups: int = world_size // data_parallel_size
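
For concreteness, a worked example of the size bookkeeping above; the numbers are illustrative, not taken from the commit.

```python
# Illustrative arithmetic only.
world_size = 16
tensor_model_parallel_size = 2
pipeline_model_parallel_size = 4

# Mirrors the divisibility check above: 16 % (2 * 4) == 0, so no RuntimeError.
data_parallel_size = world_size // (tensor_model_parallel_size * pipeline_model_parallel_size)  # 2

num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size      # 8 groups of 2 ranks
num_pipeline_model_parallel_groups = world_size // pipeline_model_parallel_size  # 4 groups of 4 ranks
num_data_parallel_groups = world_size // data_parallel_size                      # 8 groups of 2 ranks
```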
if virtual_pipeline_model_parallel_size_ is not None:
assert pipeline_model_parallel_size_ > 2, \
'pipeline-model-parallel size should be greater than 2 with ' \
'interleaved schedule'
# assert pipeline_model_parallel_size_ > 2, (
# "pipeline-model-parallel size should be greater than 2 with "
# "interleaved schedule"
# )
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = virtual_pipeline_model_parallel_size_
_VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE = (
virtual_pipeline_model_parallel_size_
)
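
As a side note, a rough illustration of what the virtual (interleaved) pipeline size implies for model partitioning; this is an assumption-level sketch, not code from the file.

```python
# Illustration: with interleaving, each pipeline stage hosts `virtual_size` model chunks,
# so the model is partitioned into pipeline_size * virtual_size chunks overall.
pipeline_model_parallel_size = 4
virtual_pipeline_model_parallel_size = 2

total_model_chunks = pipeline_model_parallel_size * virtual_pipeline_model_parallel_size  # 8
chunks_per_stage = virtual_pipeline_model_parallel_size                                   # 2
```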
if pipeline_model_parallel_split_rank_ is not None:
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
@@ -134,16 +167,23 @@ def initialize_model_parallel(
global _MODEL_PARALLEL_GROUP
assert _MODEL_PARALLEL_GROUP is None, "model parallel group is already initialized"
for i in range(data_parallel_size):
ranks = [data_parallel_group_ranks[i] for data_parallel_group_ranks in all_data_parallel_group_ranks]
ranks = [
data_parallel_group_ranks[i]
for data_parallel_group_ranks in all_data_parallel_group_ranks
]
group = torch.distributed.new_group(ranks)
if rank in ranks:
_MODEL_PARALLEL_GROUP = group
# Build the tensor model-parallel groups.
global _TENSOR_MODEL_PARALLEL_GROUP
assert _TENSOR_MODEL_PARALLEL_GROUP is None, "tensor model parallel group is already initialized"
assert (
_TENSOR_MODEL_PARALLEL_GROUP is None
), "tensor model parallel group is already initialized"
for i in range(num_tensor_model_parallel_groups):
ranks = range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
ranks = list(
range(i * tensor_model_parallel_size, (i + 1) * tensor_model_parallel_size)
)
group = torch.distributed.new_group(ranks)
if rank in ranks:
_TENSOR_MODEL_PARALLEL_GROUP = group
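
To make the grouping loops concrete, a standalone enumeration of the rank layout (following the standard Megatron-LM convention that this file mirrors; the configuration is illustrative):

```python
# Standalone illustration: world_size = 8, tensor parallel = 2, pipeline parallel = 2.
world_size, tp, pp = 8, 2, 2

tensor_groups = [list(range(i * tp, (i + 1) * tp)) for i in range(world_size // tp)]
# [[0, 1], [2, 3], [4, 5], [6, 7]]

data_groups = []
for i in range(pp):
    start, end = i * (world_size // pp), (i + 1) * (world_size // pp)
    for j in range(tp):
        data_groups.append(list(range(start + j, end, tp)))
# [[0, 2], [1, 3], [4, 6], [5, 7]]

pipeline_groups = [list(range(i, world_size, world_size // pp)) for i in range(world_size // pp)]
# [[0, 4], [1, 5], [2, 6], [3, 7]]
```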
@@ -152,10 +192,17 @@ def initialize_model_parallel(
# (first and last rank in each pipeline model-parallel group).
global _PIPELINE_MODEL_PARALLEL_GROUP
global _PIPELINE_GLOBAL_RANKS
assert _PIPELINE_MODEL_PARALLEL_GROUP is None, "pipeline model parallel group is already initialized"
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is None
), "pipeline model parallel group is already initialized"
global _EMBEDDING_GROUP
global _EMBEDDING_GLOBAL_RANKS
assert _EMBEDDING_GROUP is None, "embedding group is already initialized"
global _POSITION_EMBEDDING_GROUP
global _POSITION_EMBEDDING_GLOBAL_RANKS
assert (
_POSITION_EMBEDDING_GROUP is None
), "position embedding group is already initialized"
for i in range(num_pipeline_model_parallel_groups):
ranks = range(i, world_size, num_pipeline_model_parallel_groups)
group = torch.distributed.new_group(ranks)
@@ -167,22 +214,38 @@ def initialize_model_parallel(
if len(ranks) > 1:
embedding_ranks = [ranks[0], ranks[-1]]
position_embedding_ranks = [ranks[0]]
if (
pipeline_model_parallel_split_rank_ is not None and
ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks
):
if pipeline_model_parallel_split_rank_ is not None:
if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_], ranks[-1]]
if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
position_embedding_ranks = [ranks[0], ranks[pipeline_model_parallel_split_rank_]]
embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank_],
ranks[-1],
]
if (
ranks[pipeline_model_parallel_split_rank_]
not in position_embedding_ranks
):
position_embedding_ranks = [
ranks[0],
ranks[pipeline_model_parallel_split_rank_],
]
else:
embedding_ranks = ranks
position_embedding_ranks = ranks
group = torch.distributed.new_group(embedding_ranks)
if rank in embedding_ranks:
_EMBEDDING_GROUP = group
if rank in ranks:
_EMBEDDING_GLOBAL_RANKS = embedding_ranks
group = torch.distributed.new_group(position_embedding_ranks)
if rank in position_embedding_ranks:
_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
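
A small worked example of the rank selection above, for a single pipeline-parallel group of four stages with an encoder/decoder split at stage 2 (the global ranks are illustrative):

```python
# Illustration of the branch above, evaluated outside the initializer.
ranks = [0, 4, 8, 12]                    # one pipeline-parallel group (global ranks)
pipeline_model_parallel_split_rank_ = 2  # index of the first decoder stage

embedding_ranks = [ranks[0], ranks[-1]]  # [0, 12]
position_embedding_ranks = [ranks[0]]    # [0]
if pipeline_model_parallel_split_rank_ is not None:
    split_rank = ranks[pipeline_model_parallel_split_rank_]  # 8
    if split_rank not in embedding_ranks:
        embedding_ranks = [ranks[0], split_rank, ranks[-1]]  # [0, 8, 12]
    if split_rank not in position_embedding_ranks:
        position_embedding_ranks = [ranks[0], split_rank]    # [0, 8]
```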
def get_rank_info() -> Tuple[int, int, int]:
"""Returns a tuple of (data, tensor, pipeline, virtual pipeline)-parallel-rank for logger."""
if model_parallel_is_initialized():
@@ -197,7 +260,11 @@ def get_rank_info() -> Tuple[int, int, int]:
def model_parallel_is_initialized():
"""Check if model and data parallel groups are initialized."""
if _TENSOR_MODEL_PARALLEL_GROUP is None or _PIPELINE_MODEL_PARALLEL_GROUP is None or _DATA_PARALLEL_GROUP is None:
if (
_TENSOR_MODEL_PARALLEL_GROUP is None
or _PIPELINE_MODEL_PARALLEL_GROUP is None
or _DATA_PARALLEL_GROUP is None
):
return False
return True
@@ -210,13 +277,17 @@ def get_model_parallel_group():
def get_tensor_model_parallel_group():
"""Get the tensor model parallel group the caller rank belongs to."""
assert _TENSOR_MODEL_PARALLEL_GROUP is not None, "intra_layer_model parallel group is not initialized"
assert (
_TENSOR_MODEL_PARALLEL_GROUP is not None
), "intra_layer_model parallel group is not initialized"
return _TENSOR_MODEL_PARALLEL_GROUP
def get_pipeline_model_parallel_group():
"""Get the pipeline model parallel group the caller rank belongs to."""
assert _PIPELINE_MODEL_PARALLEL_GROUP is not None, "pipeline_model parallel group is not initialized"
assert (
_PIPELINE_MODEL_PARALLEL_GROUP is not None
), "pipeline_model parallel group is not initialized"
return _PIPELINE_MODEL_PARALLEL_GROUP
@@ -232,6 +303,14 @@ def get_embedding_group():
return _EMBEDDING_GROUP
def get_position_embedding_group():
"""Get the position embedding group the caller rank belongs to."""
assert (
_POSITION_EMBEDDING_GROUP is not None
), "position embedding group is not initialized"
return _POSITION_EMBEDDING_GROUP
def is_rank_in_embedding_group(ignore_virtual=False):
"""Return true if current rank is in embedding group, False otherwise."""
rank = torch.distributed.get_rank()
@@ -248,6 +327,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
return False
def is_rank_in_position_embedding_group():
"""Return whether the current rank is in position embedding group."""
rank = torch.distributed.get_rank()
global _POSITION_EMBEDDING_GLOBAL_RANKS
return rank in _POSITION_EMBEDDING_GLOBAL_RANKS
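
A plausible use of the two position-embedding helpers (a sketch, not part of this commit): keeping position-embedding gradients identical across the pipeline stages that hold a copy, analogous to how the word-embedding group is used in Megatron-LM. The `position_embeddings` attribute name is an assumption.

```python
import torch
from apex.transformer import parallel_state

def allreduce_position_embedding_grads(model) -> None:
    """Sketch: all-reduce position-embedding grads across the ranks that share a copy."""
    if (
        parallel_state.is_rank_in_position_embedding_group()
        and parallel_state.get_pipeline_model_parallel_world_size() > 1
    ):
        grad = model.position_embeddings.weight.grad  # assumed attribute layout
        torch.distributed.all_reduce(grad, group=parallel_state.get_position_embedding_group())
```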
def is_pipeline_stage_before_split(rank=None):
"""Return True if pipeline stage executes encoder block for a model
with both encoder and decoder."""
@@ -283,7 +369,9 @@ def is_pipeline_stage_at_split():
stage executes encoder block for a model with both encoder and
decoder."""
rank = get_pipeline_model_parallel_rank()
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(rank + 1)
return is_pipeline_stage_before_split(rank) and is_pipeline_stage_after_split(
rank + 1
)
def set_tensor_model_parallel_world_size(world_size):
@@ -342,6 +430,9 @@ def get_pipeline_model_parallel_rank():
return torch.distributed.get_rank(group=get_pipeline_model_parallel_group())
# TODO (mkozuki): Add [`get_num_layers`](https://github.com/NVIDIA/Megatron-LM/blob/e156d2fea7fc5c98e645f7742eb86b643956d840/megatron/mpu/initialize.py#L321) here, maybe?
def get_pipeline_model_parallel_split_rank():
"""Return my rank for the pipeline model parallel split rank."""
global _PIPELINE_MODEL_PARALLEL_SPLIT_RANK
@@ -368,12 +459,16 @@ def is_pipeline_first_stage(ignore_virtual=False):
def is_pipeline_last_stage(ignore_virtual=False):
"""Return True if in the last pipeline model-parallel stage, False otherwise."""
if not ignore_virtual:
virtual_pipeline_model_parallel_world_size = get_virtual_pipeline_model_parallel_world_size()
virtual_pipeline_model_parallel_world_size = (
get_virtual_pipeline_model_parallel_world_size()
)
if virtual_pipeline_model_parallel_world_size is not None and get_virtual_pipeline_model_parallel_rank() != (
virtual_pipeline_model_parallel_world_size - 1
):
return False
return get_pipeline_model_parallel_rank() == (get_pipeline_model_parallel_world_size() - 1)
return get_pipeline_model_parallel_rank() == (
get_pipeline_model_parallel_world_size() - 1
)
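
To illustrate the virtual-pipeline branch above, a tiny self-contained mirror of the logic with a couple of checked cases (4 pipeline stages, interleaving factor 2; illustrative only):

```python
def is_last_stage_example(pipe_rank, pipe_size, virtual_rank, virtual_size, ignore_virtual=False):
    # Mirrors is_pipeline_last_stage() without touching the global parallel state.
    if not ignore_virtual and virtual_size is not None and virtual_rank != virtual_size - 1:
        return False
    return pipe_rank == pipe_size - 1

assert is_last_stage_example(3, 4, 0, 2) is False                      # last stage, first model chunk
assert is_last_stage_example(3, 4, 1, 2) is True                       # last stage, last model chunk
assert is_last_stage_example(3, 4, 0, 2, ignore_virtual=True) is True  # chunk position ignored
```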
def get_virtual_pipeline_model_parallel_rank():
@@ -402,26 +497,42 @@ def get_tensor_model_parallel_src_rank():
return (global_rank // local_world_size) * local_world_size
def get_data_parallel_src_rank():
"""Calculate the global rank corresponding to the first local rank in the data parallel group."""
global_rank = torch.distributed.get_rank()
data_parallel_size: int = get_data_parallel_world_size()
num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
return global_rank % num_data_parallel_groups
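
A worked example of the two source-rank helpers for a simple configuration without pipeline parallelism (world size 8, tensor parallel 2); the numbers are illustrative.

```python
# tensor groups: [0, 1], [2, 3], [4, 5], [6, 7]
# data groups:   [0, 2, 4, 6], [1, 3, 5, 7]
global_rank = 5
tensor_model_parallel_size = 2   # plays the role of local_world_size above
num_data_parallel_groups = 2     # world_size // data_parallel_size = 8 // 4

tensor_src = (global_rank // tensor_model_parallel_size) * tensor_model_parallel_size  # 4
data_src = global_rank % num_data_parallel_groups                                      # 1 (first rank of [1, 3, 5, 7])
```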
def get_pipeline_model_parallel_first_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
return _PIPELINE_GLOBAL_RANKS[0]
def get_pipeline_model_parallel_last_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
last_rank_local = get_pipeline_model_parallel_world_size() - 1
return _PIPELINE_GLOBAL_RANKS[last_rank_local]
def get_pipeline_model_parallel_next_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline + 1) % world_size]
def get_pipeline_model_parallel_prev_rank():
assert _PIPELINE_GLOBAL_RANKS is not None, "Pipeline parallel group is not initialized"
assert (
_PIPELINE_GLOBAL_RANKS is not None
), "Pipeline parallel group is not initialized"
rank_in_pipeline = get_pipeline_model_parallel_rank()
world_size = get_pipeline_model_parallel_world_size()
return _PIPELINE_GLOBAL_RANKS[(rank_in_pipeline - 1) % world_size]
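
The next/prev helpers above are typically consumed by point-to-point communication between adjacent pipeline stages. A rough sketch of that pattern (tensor shapes, dtype, and the overall schedule are assumptions, not part of this file):

```python
import torch
from apex.transformer import parallel_state

def exchange_activations(activations: torch.Tensor) -> torch.Tensor:
    """Sketch: send activations to the next stage, receive from the previous one."""
    ops = []
    if not parallel_state.is_pipeline_last_stage():
        ops.append(torch.distributed.P2POp(
            torch.distributed.isend, activations,
            parallel_state.get_pipeline_model_parallel_next_rank()))
    recv_buf = torch.empty_like(activations)  # assumes the same shape/dtype on every stage
    if not parallel_state.is_pipeline_first_stage():
        ops.append(torch.distributed.P2POp(
            torch.distributed.irecv, recv_buf,
            parallel_state.get_pipeline_model_parallel_prev_rank()))
    if ops:
        for req in torch.distributed.batch_isend_irecv(ops):
            req.wait()
    return recv_buf
```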
@@ -437,6 +548,9 @@ def get_data_parallel_rank():
return torch.distributed.get_rank(group=get_data_parallel_group())
# note (mkozuki): `destroy_model_parallel` voids more global variables than Megatron-LM.
# Otherwise pipeline parallel forward_backward functions test hangs possibly because
# the clean-up of the original is NOT enough.
def destroy_model_parallel():
"""Set the groups to none."""
global _MODEL_PARALLEL_GROUP
......
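
Per the note above, tests that re-initialize parallel state repeatedly rely on a complete teardown. A minimal sketch of that initialize/destroy cycle (the helper name and configuration are placeholders):

```python
from apex.transformer import parallel_state

def run_with_parallel_state(pipeline_size: int, fn) -> None:
    """Placeholder helper: run `fn` under a fresh parallel state, then tear it down."""
    parallel_state.initialize_model_parallel(
        tensor_model_parallel_size_=1,
        pipeline_model_parallel_size_=pipeline_size,
    )
    try:
        fn()
    finally:
        # Reset all group/rank globals so the next initialization starts clean.
        parallel_state.destroy_model_parallel()
```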