Unverified Commit fb21698e authored by Tim Moon, committed by GitHub
Browse files

Fix bug when initializing model-parallel process groups for GPT-3 (#1435)

* Hack to enable training GPT-3

Seems to fix bug from #1416

* Add test to initialize model-parallelism for decoder-only Transformers

Namely GPT-3.
parent e57d9e79
......@@ -44,7 +44,7 @@ _EMBEDDING_GROUP = None
_POSITION_EMBEDDING_GROUP = None
# Relative position embedding group.
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
......@@ -294,14 +294,16 @@ def initialize_model_parallel(
if rank in ranks:
_POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
if encoder_relative_position_embedding_ranks:
group = torch.distributed.new_group(encoder_relative_position_embedding_ranks)
if rank in encoder_relative_position_embedding_ranks:
_ENCODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
_ENCODER_RELATIVE_POSITION_EMBEDDING_GLOBAL_RANKS = \
encoder_relative_position_embedding_ranks
group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
if decoder_relative_position_embedding_ranks:
group = torch.distributed.new_group(decoder_relative_position_embedding_ranks)
if rank in decoder_relative_position_embedding_ranks:
_DECODER_RELATIVE_POSITION_EMBEDDING_GROUP = group
if rank in ranks:
......
......@@ -132,6 +132,50 @@ class ParallelStateTestBase:
parallel_state.destroy_model_parallel()
def test_initialize_model_parallel_decoder_only(self) -> None:
    """Initialize model parallelism for decoder-only Transformers like GPT-3"""
    self.assertFalse(parallel_state.model_parallel_is_initialized())
    for tp_size in range(1, self.world_size + 1):
        with self.subTest(tensor_model_parallel_world_size=tp_size):
            # Only world sizes that factor evenly into (tensor, pipeline) are valid.
            if self.world_size % tp_size != 0:
                continue
            pp_size = self.world_size // tp_size
            # split rank 0 marks the whole pipeline as decoder (GPT-3 style).
            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tp_size,
                pipeline_model_parallel_size_=pp_size,
                pipeline_model_parallel_split_rank_=0,
            )

            # Tensor-parallel world size should round-trip through parallel_state.
            self.assertEqual(
                tp_size,
                parallel_state.get_tensor_model_parallel_world_size(),
            )

            # Rank within the tensor-parallel group.
            expected_tp_rank = calc_expected_tensor_model_paralell_rank(
                self.rank, tp_size
            )
            self.assertEqual(
                expected_tp_rank,
                parallel_state.get_tensor_model_parallel_rank(),
            )

            # First global rank of this rank's tensor-parallel group.
            expected_tp_src_rank = (self.rank // tp_size) * tp_size
            self.assertEqual(
                expected_tp_src_rank,
                parallel_state.get_tensor_model_parallel_src_rank(),
            )

            parallel_state.destroy_model_parallel()
            self.assertFalse(parallel_state.model_parallel_is_initialized())
# Runs the shared ParallelStateTestBase cases over the NCCL distributed backend.
class NcclParallelStateTest(ParallelStateTestBase, NcclDistributedTestBase): pass
# Runs the shared ParallelStateTestBase cases over the UCC distributed backend.
class UccParallelStateTest(ParallelStateTestBase, UccDistributedTestBase): pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment