Unverified commit a29a698f authored by Aidyn-A, committed by GitHub

[transformer] UCC async test (#1417)

* add test

* update batch sizes

* update batch sizes

* small updates

* delete comment

* add async comm

* add sync if needed

* update tests

* remove redundant imports

* code cleanup

* minor updates

* update dtype for comparison

* fix dtypes

* fix typo

* modify sizes and use common_utils.find_free_port

* fix typo and use double precision

* revert some changes, create test for profiling on L1

* remove redundant line

* revert UCC_TLS and add sync to fwd_bwd

* code clean up

* code clean up

* modify BERT test

* add comment
parent 809043f5
@@ -53,35 +53,47 @@ def _run_p2pops(
    async_comm: bool = False
):
    ops = []
+   p2p_group = parallel_state.get_pipeline_model_parallel_group()
+   default_group = parallel_state.get_model_parallel_group()
+   need_to_sync = p2p_group.name() != default_group.name()
    if tensor_send_prev is not None:
        send_prev_op = torch.distributed.P2POp(
-           torch.distributed.isend,
-           tensor_send_prev,
-           parallel_state.get_pipeline_model_parallel_prev_rank(),
+           op=torch.distributed.isend,
+           tensor=tensor_send_prev,
+           peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+           group=p2p_group,
        )
        ops.append(send_prev_op)
    if tensor_recv_prev is not None:
        recv_prev_op = torch.distributed.P2POp(
-           torch.distributed.irecv,
-           tensor_recv_prev,
-           parallel_state.get_pipeline_model_parallel_prev_rank(),
+           op=torch.distributed.irecv,
+           tensor=tensor_recv_prev,
+           peer=parallel_state.get_pipeline_model_parallel_prev_rank(),
+           group=p2p_group,
        )
        ops.append(recv_prev_op)
    if tensor_send_next is not None:
        send_next_op = torch.distributed.P2POp(
-           torch.distributed.isend,
-           tensor_send_next,
-           parallel_state.get_pipeline_model_parallel_next_rank(),
+           op=torch.distributed.isend,
+           tensor=tensor_send_next,
+           peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+           group=p2p_group,
        )
        ops.append(send_next_op)
    if tensor_recv_next is not None:
        recv_next_op = torch.distributed.P2POp(
-           torch.distributed.irecv,
-           tensor_recv_next,
-           parallel_state.get_pipeline_model_parallel_next_rank(),
+           op=torch.distributed.irecv,
+           tensor=tensor_recv_next,
+           peer=parallel_state.get_pipeline_model_parallel_next_rank(),
+           group=p2p_group,
        )
        ops.append(recv_next_op)
    if len(ops) > 0:
+       if need_to_sync:
+           torch.cuda.synchronize()
        reqs = torch.distributed.batch_isend_irecv(ops)
+       if async_comm:
+           assert len(reqs) == len(ops)
......
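The rest of this hunk is collapsed. A minimal sketch of how the requests returned by batch_isend_irecv might be consumed: the async branch and the assert come from the diff above, while the synchronous wait loop is an assumption about the elided code.

# Sketch only; assumes the truncated code waits on the requests when async_comm is False.
reqs = torch.distributed.batch_isend_irecv(ops)
if async_comm:
    # Hand the outstanding requests back to the caller, which must wait on
    # them (and typically call torch.cuda.synchronize()) before using the tensors.
    assert len(reqs) == len(ops)
else:
    # Synchronous path: block until every isend/irecv has completed.
    for req in reqs:
        req.wait()
    torch.cuda.synchronize()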
@@ -34,6 +34,7 @@ def _forward_backward_pipelining_with_interleaving(
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
+   async_comm: bool = False,
    sequence_parallel_enabled: bool = False,
    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
@@ -218,6 +219,7 @@ def _forward_backward_pipelining_with_interleaving(
        p2p_communication.recv_forward(
            tensor_shape=tensor_shape,
            dtype=dtype,
+           async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
    )
@@ -265,6 +267,7 @@ def _forward_backward_pipelining_with_interleaving(
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype=dtype,
+       async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
@@ -275,6 +278,7 @@ def _forward_backward_pipelining_with_interleaving(
        recv_prev=recv_prev,
        tensor_shape=tensor_shape,
        dtype=dtype,
+       async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    input_tensors[next_forward_model_chunk_id].append(input_tensor)
@@ -359,6 +363,7 @@ def _forward_backward_pipelining_with_interleaving(
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype=dtype,
+       async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
    free_output_tensor(output_tensor, deallocate_pipeline_outputs)
@@ -380,6 +385,7 @@ def _forward_backward_pipelining_with_interleaving(
        p2p_communication.recv_backward(
            tensor_shape=tensor_shape,
            dtype=dtype,
+           async_comm=async_comm,
            sequence_parallel_enabled=sequence_parallel_enabled,
        )
    )
@@ -401,6 +407,7 @@ def _forward_backward_pipelining_with_interleaving(
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype=dtype,
+       async_comm=async_comm,
        sequence_parallel_enabled=sequence_parallel_enabled,
    )
)
......
@@ -115,7 +115,7 @@ class UccDistributedTestBase(DistributedTestBase):
        self._has_ucx_tls = "UCX_TLS" in os.environ
        if not self._has_ucx_tls:
-           os.environ["UCX_TLS"] = "tcp,cuda_copy"
+           os.environ["UCX_TLS"] = "tcp,cuda"
        print('os.environ[\"UCX_TLS\"] = {}'.format(os.environ["UCX_TLS"]))

    def tearDown(self) -> None:
......
@@ -186,7 +186,12 @@ if __name__ == "__main__":
    failure = None
    init = True
    try:
-       for virtual_pipeline_model_parallel_size in (2, None):
+       virtual_pipeline_model_parallel_sizes = (None, 2,)
+       if HAS_TORCH_UCC:
+           # Deliberately skipping test with interleaved schedule for BERT model.
+           # It deadlocks on hybrid UCC/NCCL backend.
+           virtual_pipeline_model_parallel_sizes = (None,)
+       for virtual_pipeline_model_parallel_size in virtual_pipeline_model_parallel_sizes:
            args = global_vars.get_args()
            async_comm = not args.sequence_parallel and virtual_pipeline_model_parallel_size is None
            data_idx = 0
......
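For context, a minimal sketch of how the async_comm flag computed above might be threaded into the non-interleaved schedule, mirroring the invocation used in the L1 profiling test added later in this diff. fwd_step_func, batch, model, micro_batch_size, hidden_size, dtype, and grad_scaler are placeholders for the objects the BERT test builds, not its actual variable names.

# Illustrative call only; argument objects are placeholders.
loss = forward_backward_pipelining_without_interleaving(
    fwd_step_func,                      # per-microbatch forward step
    batch,                              # input batch (first pipeline stage only)
    model,
    forward_only=False,
    tensor_shape=(micro_batch_size, hidden_size, hidden_size),
    dtype=dtype,
    async_comm=async_comm,              # only True without sequence parallel / interleaving
    grad_scaler=grad_scaler,
)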
@@ -49,12 +49,19 @@ def get_init_weights_func(offset: int = 0):
    return init_weights


+def get_dtype_for_comparison():
+   if(torch.cuda.get_device_capability() >= (8, 0)):
+       return torch.float64
+   return torch.float32


def get_target_loss_and_model(global_batch_shape: tuple, hidden_size: int, total_layers: int) -> Tuple[torch.Tensor, List[torch.Tensor]]:
    model = []
-   data = torch.ones(global_batch_shape, dtype=torch.double)
+   dtype = get_dtype_for_comparison()
+   data = torch.ones(global_batch_shape, dtype=dtype)
    for i in range(total_layers):
-       w = torch.ones((hidden_size, hidden_size), dtype=torch.double) * (i + 1.0) / weight_coeff
-       b = torch.ones(hidden_size, dtype=torch.double)
+       w = torch.ones((hidden_size, hidden_size), dtype=dtype) * (i + 1.0) / weight_coeff
+       b = torch.ones(hidden_size, dtype=dtype)
        w.requires_grad_()
        b.requires_grad_()
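A likely rationale for get_dtype_for_comparison, not stated in the diff: on devices with compute capability 8.0 or newer, float32 matmuls may run in TF32 by default, so the reference computation switches to float64 there. An alternative sketch, assuming one instead wanted to keep float32, would pin matmul precision explicitly:

# Alternative sketch (assumption, not what the test does): disable TF32 so
# float32 results stay comparable on Ampere and newer GPUs.
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False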
@@ -187,7 +194,8 @@ class PipelineParallelForwardBackwardTestBase:
            deallocate_pipeline_output=deallocate_pipeline_outputs,
        )

-       if dtype == torch.double:
+       if dtype == get_dtype_for_comparison():
+           torch.cuda.synchronize()
            hidden_size = self.HIDDEN_SIZE
            microbatch_size = self.MICRO_BATCH_SIZE
            total_layers = pipeline_model_parallel_world_size
@@ -221,44 +229,56 @@ class PipelineParallelForwardBackwardTestBase:
        parallel_state.destroy_model_parallel()

-   def test_no_pipelining(self):
+   def test_learning_no_pipelining(self):
        self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)

-   def test_no_pipelining_inference(self):
+   def test_inference_no_pipelining(self):
        self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)

-   def test_pipelining_without_interleaving(self):
+   def test_learning_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            False, forward_backward_pipelining_without_interleaving, None, None
        )

-   def test_pipelining_async(self):
+   def test_inference_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
-           False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
+           True, forward_backward_pipelining_without_interleaving, None, None
        )

-   def test_pipelining_without_interleaving_inference(self):
+   def test_learning_async_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
-           True, forward_backward_pipelining_without_interleaving, None, None
+           False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
        )

-   def test_pipelining_inference_async(self):
+   def test_inference_async_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
        )

-   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
-   def test_pipelining_with_interleaving(self):
+   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+   def test_learning_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
        )

-   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Megatron-LM voodoo")
-   def test_pipelining_with_interleaving_inference(self):
+   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+   def test_inference_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
        )

+   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+   def test_learning_async_pipelining_with_interleaving(self):
+       self._forward_backward_test_impl(
+           False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
+       )

+   @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
+   def test_inference_async_pipelining_with_interleaving(self):
+       self._forward_backward_test_impl(
+           True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
+       )
class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
@@ -283,10 +303,10 @@ class NcclPipelineParallelForwardBackwardTest(NcclDistributedTestBase, PipelineParallelForwardBackwardTestBase):
    ):
        self._run_hybrid_distributed_backend(forward_only)

-   def test_pipelining_without_interleaving_ucc_for_p2p(self):
+   def test_learning_pipelining_without_interleaving_ucc_for_p2p(self):
        self._test_hybrid_backends(False)

-   def test_pipelining_without_interleaving_inference_ucc_for_p2p(self):
+   def test_inference_pipelining_without_interleaving_ucc_for_p2p(self):
        self._test_hybrid_backends(True)
......
import os
import logging
import itertools
from typing import Optional, Tuple, List
import unittest
import torch
from torch.testing._internal import common_utils
from torch.testing._internal import common_cuda
from torch.testing._internal import common_distributed
from apex._autocast_utils import _get_autocast_dtypes
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import utils as pp_utils
from apex.transformer.pipeline_parallel.schedules.common import (
FwdStepFunc,
build_model,
_get_params_for_weight_decay_optimization,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_no_pipelining import (
forward_backward_no_pipelining,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_with_interleaving import (
_forward_backward_pipelining_with_interleaving,
)
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
forward_backward_pipelining_without_interleaving,
)
from apex.transformer.testing.distributed_test_base import UccDistributedTestBase
from apex.transformer.testing import commons as testing_utils
logging.getLogger("torch").setLevel(logging.WARNING)
logging.getLogger("apex").setLevel(logging.WARNING)
def _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size: Optional[int] = None
) -> Tuple[int, int, int]:
    # TODO: revisit if we can fold this into the class for skip logic / avoid duplication
    # of world size computation
    world_size = torch.cuda.device_count()
    tensor_model_parallel_world_size = 1
    data_parallel_size = 1 + (world_size >= 8 and world_size % 2 == 0)

    if pipeline_model_parallel_world_size is None:
        pipeline_model_parallel_world_size = world_size // (tensor_model_parallel_world_size * data_parallel_size)
    else:
        data_parallel_size = world_size // (tensor_model_parallel_world_size * pipeline_model_parallel_world_size)

    return tensor_model_parallel_world_size, data_parallel_size, pipeline_model_parallel_world_size
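A concrete example of the sizing logic above, assuming the test runs in this module on a host with 8 visible GPUs:

# With torch.cuda.device_count() == 8 and no explicit pipeline size:
#   tensor_model_parallel_world_size   = 1
#   data_parallel_size                 = 1 + (8 >= 8 and 8 % 2 == 0) = 2
#   pipeline_model_parallel_world_size = 8 // (1 * 2) = 4
tp, dp, pp = _get_default_world_sizes_model_parallel_world_size()
assert (tp, dp, pp) == (1, 2, 4)  # holds only on an 8-GPU machine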
class UccPipelineParallelForwardBackwardProf(UccDistributedTestBase):

    # The purpose of this class is to test and confirm asynchronous communication via profiling.
    # Having that in mind, it is safe to skip all the numerical checks.
    # For unit testing with numerical checks please refer to `tests/L0/run_transformer/test_pipeline_parallel_fwd_bwd.py`.

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)
        self.GLOBAL_BATCH_SIZE = 1024
        self.MICRO_BATCH_SIZE = 64
        self.HIDDEN_SIZE = 256
        self.NUM_FWD_BWD_ITERATIONS = 4
        self.deallocate_options = (False,)
        self.dtypes = (torch.float32,)

    @property
    def world_size(self) -> int:
        return min(torch.cuda.device_count(), 8)
    def _forward_backward_test_impl(
        self,
        forward_only: bool,
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
        async_comm: bool = False,
        *,
        default_backend: Optional[str] = None,
        p2p_backend: Optional[str] = None,
    ) -> None:
        if fwd_bwd_func == _forward_backward_pipelining_with_interleaving:
            self.assertIsNotNone(virtual_pipeline_model_parallel_size)
            self.assertGreater(virtual_pipeline_model_parallel_size, 1)
        dtype_options = self.dtypes or [torch.float32, torch.double] + _get_autocast_dtypes()

        for dtype, deallocate_pipeline_outputs in itertools.product(
            dtype_options, self.deallocate_options,
        ):
            grad_scaler = (
                torch.cuda.amp.GradScaler(init_scale=4.0)
                if dtype == torch.half
                else None
            )

            (tensor_model_parallel_world_size,
             data_parallel_size,
             pipeline_model_parallel_world_size) = _get_default_world_sizes_model_parallel_world_size(pipeline_model_parallel_world_size)

            parallel_state.initialize_model_parallel(
                tensor_model_parallel_size_=tensor_model_parallel_world_size,
                pipeline_model_parallel_size_=pipeline_model_parallel_world_size,
                virtual_pipeline_model_parallel_size_=virtual_pipeline_model_parallel_size,
                default_backend=default_backend,
                p2p_backend=p2p_backend,
            )
            pp_utils._reconfigure_microbatch_calculator(
                rank=parallel_state.get_tensor_model_parallel_rank(),
                rampup_batch_size=None,
                global_batch_size=self.GLOBAL_BATCH_SIZE,
                micro_batch_size=self.MICRO_BATCH_SIZE,
                data_parallel_size=parallel_state.get_data_parallel_world_size(),
            )

            global_batch_shape = (
                self.GLOBAL_BATCH_SIZE
                // parallel_state.get_data_parallel_world_size(),
                self.HIDDEN_SIZE,
                self.HIDDEN_SIZE,
            )

            batch = None
            if parallel_state.is_pipeline_first_stage():
                batch = (torch.ones(global_batch_shape, dtype=dtype).cuda(), )

            model = build_model(
                testing_utils.model_provider_func,
                # Use DDP only when it's better to have
                wrap_with_ddp=data_parallel_size > 1,
                virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
                hidden_size=self.HIDDEN_SIZE,
            )

            offset = pipeline_model_parallel_world_size if virtual_pipeline_model_parallel_size is not None else 0
            for idx, model_module in enumerate(model):
                model_module = model_module.to(dtype)

            _param_groups = _get_params_for_weight_decay_optimization(model)
            optimizer = torch.optim.Adam(_param_groups, lr=1e-3)

            pp_utils.update_num_microbatches(0)

            for _ in range(self.NUM_FWD_BWD_ITERATIONS):
                loss = fwd_bwd_func(
                    testing_utils.fwd_step_func,
                    batch,
                    model,
                    forward_only=forward_only,
                    # `tensor_shape` is the shape of micro batch.
                    tensor_shape=(
                        self.MICRO_BATCH_SIZE,
                        self.HIDDEN_SIZE,
                        self.HIDDEN_SIZE,
                    ),
                    dtype=dtype,
                    async_comm=async_comm,
                    grad_scaler=grad_scaler,
                    deallocate_pipeline_output=deallocate_pipeline_outputs,
                )

            parallel_state.destroy_model_parallel()
    def test_learning_no_pipelining(self):
        self._forward_backward_test_impl(False, forward_backward_no_pipelining, 1, None)

    def test_inference_no_pipelining(self):
        self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)

    def test_learning_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            False, forward_backward_pipelining_without_interleaving, None, None
        )

    def test_inference_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            True, forward_backward_pipelining_without_interleaving, None, None
        )

    def test_learning_async_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
        )

    def test_inference_async_pipelining_without_interleaving(self):
        self._forward_backward_test_impl(
            True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
        )

    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
    def test_learning_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
        )

    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
    def test_inference_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2
        )

    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
    def test_learning_async_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            False, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
        )

    @unittest.skipUnless(_get_default_world_sizes_model_parallel_world_size()[-1] > 2, "Interleaved schedule requires pipeline_model_parallel_world_size > 2")
    def test_inference_async_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            True, _forward_backward_pipelining_with_interleaving, None, virtual_pipeline_model_parallel_size=2, async_comm=True
        )
if __name__ == "__main__":
    os.environ["UCC_TLS"] = "ucp,cuda"
    common_distributed.TIMEOUT_DEFAULT = 500
    common_utils.run_tests()
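The class above intentionally skips numerical checks and relies on profiling to confirm that P2P communication overlaps with compute. A minimal sketch of how one could capture a trace around a single schedule invocation with torch.profiler; this helper is not part of the test and its names are illustrative:

import torch
from torch.profiler import profile, ProfilerActivity

def trace_one_iteration(fwd_bwd_func, *args, trace_path="ucc_async_trace.json", **kwargs):
    # Record CPU and CUDA activity for one fwd/bwd call so that overlapping
    # P2P kernels can be inspected in a trace viewer.
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        fwd_bwd_func(*args, **kwargs)
        torch.cuda.synchronize()
    prof.export_chrome_trace(trace_path)  # open in chrome://tracing or Perfetto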