Commit e7e99721 authored by Jared Casper
Browse files

Check for pipeline_parallel > 2 when using interleaving.

parent cbf780d0
...@@ -87,7 +87,8 @@ def initialize_model_parallel( ...@@ -87,7 +87,8 @@ def initialize_model_parallel(
if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0: if world_size % (tensor_model_parallel_size * pipeline_model_parallel_size) != 0:
raise RuntimeError( raise RuntimeError(
f"world_size ({world_size}) is not divisible by tensor_model_parallel_size ({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})" f"world_size ({world_size}) is not divisible by tensor_model_parallel_size "
f"({tensor_model_parallel_size}) x pipeline_model_parallel_size ({pipeline_model_parallel_size})"
) )
data_parallel_size: int = world_size // (tensor_model_parallel_size * data_parallel_size: int = world_size // (tensor_model_parallel_size *
...@@ -98,6 +99,9 @@ def initialize_model_parallel( ...@@ -98,6 +99,9 @@ def initialize_model_parallel(
num_data_parallel_groups: int = world_size // data_parallel_size num_data_parallel_groups: int = world_size // data_parallel_size
if virtual_pipeline_model_parallel_size is not None: if virtual_pipeline_model_parallel_size is not None:
if not pipeline_model_parallel_size_ > 2:
raise RuntimeError("pipeline-model-parallel size should be greater than 2 with "
"interleaved schedule")
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
_VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0 _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment