Unverified commit 265b451d, authored by eqy, committed by GitHub

Do pipeline parallelism tests in double because TF32 environment variables can be painful to manage across test suites (#1391)

* check in

* skip interleaved with 2 GPU

* change type annotation

* address comments thanks @crcrpar @Aidyn-A
parent ab5fc48f
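
For context, TF32 only affects float32 math on Ampere-and-newer GPUs, and it can be toggled both from Python and from the environment, which is what makes it awkward to keep consistent across test suites; float64 math is never routed through TF32. A minimal sketch of the knobs involved (standard PyTorch flags plus the NVIDIA_TF32_OVERRIDE environment variable; nothing here is apex-specific):

```python
import os
import torch

# TF32 can be flipped from several places, which is hard to pin down across test suites:
#   - the NVIDIA_TF32_OVERRIDE environment variable (read by the CUDA libraries at init)
#   - torch.backends.cuda.matmul.allow_tf32 (float32 matmuls)
#   - torch.backends.cudnn.allow_tf32 (cuDNN convolutions)
print(os.environ.get("NVIDIA_TF32_OVERRIDE"))
print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)

# None of these settings change float64 computation, so running the reference
# check in torch.double sidesteps the whole problem.
```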
@@ -154,10 +154,13 @@ def initialize_model_parallel(
     num_data_parallel_groups: int = world_size // data_parallel_size
 
     if virtual_pipeline_model_parallel_size_ is not None:
-        # assert pipeline_model_parallel_size_ > 2, (
-        #     "pipeline-model-parallel size should be greater than 2 with "
-        #     "interleaved schedule"
-        # )
+        # n.b. (eqy) This check was inherited from Megatron-LM, need to revisit
+        # the root cause as we do see numerical mismatches with 2 stages and
+        # the interleaved schedule
+        assert pipeline_model_parallel_size_ > 2, (
+            "pipeline-model-parallel size should be greater than 2 with "
+            "interleaved schedule"
+        )
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK
         global _VIRTUAL_PIPELINE_MODEL_PARALLEL_WORLD_SIZE
         _VIRTUAL_PIPELINE_MODEL_PARALLEL_RANK = 0
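
The effect of restoring the assert, shown as a standalone sketch of just the validation logic (not a call into apex): requesting the interleaved schedule while running with only 2 pipeline stages now fails fast instead of producing the numerical mismatches mentioned in the comment.

```python
# Standalone sketch of the restored check, outside of apex.
def _check_interleaved_schedule(pipeline_model_parallel_size_, virtual_pipeline_model_parallel_size_):
    if virtual_pipeline_model_parallel_size_ is not None:
        assert pipeline_model_parallel_size_ > 2, (
            "pipeline-model-parallel size should be greater than 2 with "
            "interleaved schedule"
        )

_check_interleaved_schedule(4, 2)     # fine: 4 stages, interleaved schedule
_check_interleaved_schedule(2, None)  # fine: 2 stages, non-interleaved schedule
# _check_interleaved_schedule(2, 2)   # AssertionError: the case the tests now skip
```

The hunks that follow are all in the pipeline-parallel forward/backward test.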
 import logging
 import itertools
 import re
-from typing import Optional
+from typing import Optional, Tuple, List
 import unittest
 import torch
@@ -48,12 +48,12 @@ def get_init_weights_func(offset: int = 0):
     return init_weights
 
 
-def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> float:
+def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> Tuple[torch.Tensor, List[torch.Tensor]]:
     model = []
-    data = torch.ones(global_batch_shape, dtype=torch.float)
+    data = torch.ones(global_batch_shape, dtype=torch.double)
     for i in range(total_layers):
-        w = torch.ones((hidden_size, hidden_size)) * (i + 1.0) / weight_coeff
-        b = torch.ones(hidden_size)
+        w = torch.ones((hidden_size, hidden_size), dtype=torch.double) * (i + 1.0) / weight_coeff
+        b = torch.ones(hidden_size, dtype=torch.double)
         w.requires_grad_()
         b.requires_grad_()
@@ -68,6 +68,22 @@ def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total
     return loss, model
 
 
+def _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size: Optional[int] = None
+    ) -> Tuple[int, int, int]:
+    # TODO: revisit if we can fold this into the class for skip logic / avoid duplication
+    # of world size computation
+    world_size = torch.cuda.device_count()
+    tensor_model_parallel_world_size = 1
+    data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0)
+    if pipeline_model_parallel_world_size is None:
+        pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size)
+    else:
+        data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)
+
+    return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size
+
+
 class PipelineParallelForwardBackwardTestBase:
 
     GLOBAL_BATCH_SIZE = 16
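
To make the new helper's arithmetic concrete, here is the same logic with torch.cuda.device_count() replaced by an explicit argument so it can be checked without GPUs (the name below is illustrative, not part of the test):

```python
def default_world_sizes(world_size, pipeline_model_parallel_world_size=None):
    # Mirrors _get_default_world_sizes_model_parallel_world_size, but takes the
    # device count as an argument instead of querying torch.cuda.device_count().
    tensor_model_parallel_world_size = 1
    data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0)
    if pipeline_model_parallel_world_size is None:
        pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size)
    else:
        data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)
    return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size

print(default_world_sizes(2))  # (1, 1, 2) -> interleaved tests are skipped (pipeline size not > 2)
print(default_world_sizes(4))  # (1, 1, 4)
print(default_world_sizes(8))  # (1, 2, 4)
```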
@@ -93,7 +109,7 @@ class PipelineParallelForwardBackwardTestBase:
         if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
             self.assertIsNotNone(virtual_pipeline_model_parallel_size)
             self.assertGreater(virtual_pipeline_model_parallel_size, 1)
-        dtype_options = self.dtypes or [torch.float32] + _get_autocast_dtypes()
+        dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()
 
         for dtype, deallocate_pipeline_outputs in itertools.product(
             dtype_options, self.deallocate_options,
@@ -103,13 +119,10 @@ class PipelineParallelForwardBackwardTestBase:
                 if dtype == torch.half
                 else None
             )
-            tensor_model_parallel_world_size = 1
-            data_parallel_size = 1 + (self.world_size >= 8 and self.world_size % 2 == 0)
-            if pipeline_model_parallel_world_size is None:
-                pipeline_model_parallel_world_size = self.world_size // (tensor_model_parallel_world_size * data_parallel_size)
-            else:
-                data_parallel_size = self.world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)
+            (tensor_model_parallel_world_size,
+             data_parallel_size,
+             pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)
 
             parallel_state.initialize_model_parallel(
                 tensor_model_parallel_size_=tensor_model_parallel_world_size,
@@ -135,7 +148,7 @@ class PipelineParallelForwardBackwardTestBase:
             batch = None
             if parallel_state.is_pipeline_first_stage():
-                batch = (torch.ones(global_batch_shape, dtype=torch.float).cuda(), )
+                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )
 
             model = build_model(
                 testing_utils.model_provider_func,
@@ -148,6 +161,7 @@ class PipelineParallelForwardBackwardTestBase:
             offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
             for idx, model_module in enumerate(model):
+                model_module = model_module.to(dtype)
                 model_module.apply(get_init_weights_func(idx*offset))
 
             _param_groups = _get_params_for_weight_decay_optimization(model)
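
The new model_module.to(dtype) line keeps each pipeline chunk's parameters consistent with the batch created above in the same dtype. A small generic illustration of the cast (plain torch.nn, not the test's model):

```python
import torch

# Module.to(dtype) converts floating-point parameters (and buffers) to the requested dtype.
layer = torch.nn.Linear(4, 4)
print(layer.weight.dtype)          # torch.float32 (default)
layer = layer.to(torch.double)
print(layer.weight.dtype)          # torch.float64
out = layer(torch.ones(2, 4, dtype=torch.double))  # input dtype now matches the parameters
```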
@@ -155,7 +169,6 @@ class PipelineParallelForwardBackwardTestBase:
             pp_utils.update_num_microbatches(0)
 
-            with common_cuda.tf32_off():
-                loss = fwd_bwd_func(
-                    testing_utils.fwd_step_func,
-                    batch,
+            loss = fwd_bwd_func(
+                testing_utils.fwd_step_func,
+                batch,
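
Dropping the common_cuda.tf32_off() guard is the core of the change: once the reference path runs in double, TF32 cannot perturb it, so there is nothing to switch off. A generic demonstration of the float32-with-TF32 drift that the old guard protected against (needs an Ampere-or-newer GPU to show a difference):

```python
import torch

if torch.cuda.is_available():
    a = torch.randn(1024, 1024, device="cuda")
    b = torch.randn(1024, 1024, device="cuda")
    ref = a.double() @ b.double()          # double matmuls are never routed through TF32

    torch.backends.cuda.matmul.allow_tf32 = True
    err_tf32 = (a @ b).double().sub(ref).abs().max()

    torch.backends.cuda.matmul.allow_tf32 = False
    err_fp32 = (a @ b).double().sub(ref).abs().max()

    # On Ampere+ the TF32 error is noticeably larger than the plain float32 error,
    # which is why a float32 "exact" reference is fragile across environments.
    print(err_tf32.item(), err_fp32.item())
```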
@@ -173,7 +186,7 @@ class PipelineParallelForwardBackwardTestBase:
                 deallocate_pipeline_output=deallocate_pipeline_outputs,
             )
 
-            if dtype == torch.float32:
+            if dtype == torch.double:
                 hidden_size = self.HIDDEN_SIZE
                 microbatch_size = self.MICRO_BATCH_SIZE
                 total_layers = pipeline_model_parallel_world_size
@@ -233,11 +246,13 @@ class PipelineParallelForwardBackwardTestBase:
             True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
         )
 
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
     def test_pipelining_with_interleaving(self):
         self._forward_backward_test_impl(
             False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )
 
+    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
    def test_pipelining_with_interleaving_inference(self):
         self._forward_backward_test_impl(
             True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
         )