Unverified commit 3fe35211, authored by eqy, committed by GitHub

Async pipeline parallel (#1373)

* initial check in

* fix

* fix test

* address some review comments and cleanup

* fix

* bookmark

* fix sync placement to come before gather

* similar fix for non-gather case

* add async bert

* update gpt minimal test

* allow selection of default pp test

* fix bert test

* cleanup

* cleanup
parent 68440264
@@ -26,11 +26,27 @@ from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers


+class FutureTensor:
+    def __init__(self, tensor: torch.Tensor, waitfunc):
+        self.tensor = tensor
+        self.waitfunc = waitfunc
+
+    def get(self):
+        if self.waitfunc is not None:
+            res = self.waitfunc()
+            if isinstance(res, torch.Tensor):
+                self.tensor = res
+            self.waitfunc = None
+        return self.tensor
def _run_p2pops(
    tensor_send_prev: Union[torch.Tensor, None],
    tensor_send_next: Union[torch.Tensor, None],
    tensor_recv_prev: Union[torch.Tensor, None],
    tensor_recv_next: Union[torch.Tensor, None],
+    async_comm: bool = False,
):
    ops = []
    if tensor_send_prev is not None:
@@ -63,8 +79,18 @@ def _run_p2pops(
        ops.append(recv_next_op)
    if len(ops) > 0:
        reqs = torch.distributed.batch_isend_irecv(ops)
+        if async_comm:
+            assert len(reqs) == len(ops)
+            tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0)
+            tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0)
+            tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0)
+            tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0)
+            return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req)
+        else:
            for req in reqs:
                req.wait()
+            return (None, None, None, None)
+    return (None, None, None, None)
def _communicate(
@@ -79,7 +105,8 @@ def _communicate(
    scatter_gather_tensors_in_pipeline: bool = True,
    params_dtype: Optional[torch.dtype] = None,
    fp32_residual_connection: bool = False,
-) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
+    async_comm: bool = False,
+) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
    """Base function for communication of tensors between stages.

    dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
@@ -161,12 +188,30 @@ def _communicate(
        tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)

    # Send tensors in both the forward and backward directions as appropriate.
-    _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
+    tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm)
+    if async_comm:
+        tensor_recv_prev_waitfunc = None
+        tensor_recv_next_waitfunc = None
+        # TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
+        # see also the sync added for the async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
+        if tensor_recv_prev_req is not None:
+            def tensor_recv_prev_wait():
+                tensor_recv_prev_req.wait()
+                torch.cuda.synchronize()
+            tensor_recv_prev_waitfunc = tensor_recv_prev_wait
+        if tensor_recv_next_req is not None:
+            def tensor_recv_next_wait():
+                tensor_recv_next_req.wait()
+                torch.cuda.synchronize()
+            tensor_recv_next_waitfunc = tensor_recv_next_wait
+    else:
        # To protect against race condition when using batch_isend_irecv().
        torch.cuda.synchronize()
    # If using scatter-gather optimization, gather smaller chunks.
    if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
+        if not async_comm:
            if recv_prev:
                tensor_recv_prev = (
                    gather_split_1d_tensor(tensor_recv_prev)
@@ -180,6 +225,35 @@ def _communicate(
                    .view(tensor_shape)
                    .requires_grad_()
                )
+        else:
+            def gather_recv_prev_wait():
+                tensor_recv_prev_req.wait()
+                # From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
+                # A sync seems to be needed before the gather, otherwise losses jump around (e.g. in run_gpt_minimal_test).
+                torch.cuda.synchronize()
+                return (
+                    gather_split_1d_tensor(tensor_recv_prev)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+
+            def gather_recv_next_wait():
+                tensor_recv_next_req.wait()
+                torch.cuda.synchronize()
+                return (
+                    gather_split_1d_tensor(tensor_recv_next)
+                    .view(tensor_shape)
+                    .requires_grad_()
+                )
+
+            tensor_recv_prev_waitfunc = gather_recv_prev_wait
+            tensor_recv_next_waitfunc = gather_recv_next_wait
+    if async_comm:
+        future_tensor_recv_prev = None
+        future_tensor_recv_next = None
+        if tensor_recv_prev is not None:
+            future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
+        if tensor_recv_next is not None:
+            future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
+        return future_tensor_recv_prev, future_tensor_recv_next

    return tensor_recv_prev, tensor_recv_next
@@ -190,7 +264,8 @@ def recv_forward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> torch.Tensor:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor, None]:
    """Receive tensor from previous rank in pipeline (forward receive)."""
    if parallel_state.is_pipeline_first_stage():
        return None
@@ -204,6 +279,7 @@ def recv_forward(
        tensor_shape=tensor_shape,
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-recv").stop()
@@ -215,7 +291,8 @@ def recv_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> torch.Tensor:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor, None]:
    """Receive tensor from next rank in pipeline (backward receive)."""
    if parallel_state.is_pipeline_last_stage():
        return None
@@ -228,6 +305,7 @@ def recv_backward(
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("backward-recv").stop()
@@ -241,6 +319,7 @@ def send_forward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
+    async_comm: bool = False,
) -> None:
    """Send tensor to next rank in pipeline (forward send)."""
    if parallel_state.is_pipeline_last_stage():
@@ -255,6 +334,7 @@ def send_forward(
        override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-send").stop()
@@ -266,6 +346,8 @@ def send_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
+    async_comm: bool = False,
) -> None:
    """Send tensor to previous rank in pipeline (backward send)."""
    if parallel_state.is_pipeline_first_stage():
@@ -279,6 +361,7 @@ def send_backward(
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("backward-send").stop()
@@ -290,7 +373,8 @@ def send_forward_recv_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> Union[None, torch.Tensor]:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor, None]:
    """Batched send and recv with next rank in pipeline."""
    if parallel_state.is_pipeline_last_stage():
        return None
@@ -303,6 +387,7 @@ def send_forward_recv_backward(
        recv_next=True,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-send-backward-recv").stop()
@@ -315,7 +400,8 @@ def send_backward_recv_forward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> Union[None, torch.Tensor]:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor, None]:
    """Batched send and recv with previous rank in pipeline."""
    if parallel_state.is_pipeline_first_stage():
        return None
@@ -328,6 +414,7 @@ def send_backward_recv_forward(
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("backward-send-forward-recv").stop()
@@ -341,7 +428,8 @@ def send_forward_recv_forward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> torch.Tensor:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor]:
    """Batched recv from previous rank and send to next rank in pipeline."""
    # if timers is not None:
    #     timers("forward-send-forward-recv").start()
@@ -352,6 +440,7 @@ def send_forward_recv_forward(
        recv_next=False,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-send-forward-recv").stop()
@@ -365,7 +454,8 @@ def send_backward_recv_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> torch.Tensor:
+    async_comm: bool = False,
+) -> Union[torch.Tensor, FutureTensor]:
    """Batched recv from next rank and send to previous rank in pipeline."""
    # if timers is not None:
    #     timers("backward-send-backward-recv").start()
@@ -376,6 +466,7 @@ def send_backward_recv_backward(
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("backward-send-backward-recv").stop()
@@ -391,7 +482,8 @@ def send_forward_backward_recv_forward_backward(
    *,
    dtype: Optional[torch.dtype] = None,
    timers: _Timers = None,
-) -> Tuple[torch.Tensor, torch.Tensor]:
+    async_comm: bool = False,
+) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
    """Batched send and recv with previous and next ranks in pipeline."""
    # if timers is not None:
    #     timers("forward-backward-send-forward-backward-recv").start()
@@ -402,6 +494,7 @@ def send_forward_backward_recv_forward_backward(
        recv_next=recv_next,
        tensor_shape=tensor_shape,
        dtype_=dtype,
+        async_comm=async_comm,
    )
    # if timers is not None:
    #     timers("forward-backward-send-forward-backward-recv").stop()
......
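The p2p changes above follow a single pattern: _run_p2pops now returns the batch_isend_irecv request handles instead of waiting on them, _communicate turns each receive handle into a wait callback (guarded by torch.cuda.synchronize(), with a gather step on the scatter-gather path), and the receive buffer plus callback are wrapped in a FutureTensor. As a toy illustration of just the deferred-wait behaviour, without any distributed setup, the snippet below substitutes a made-up fake_wait function for the real "req.wait(); torch.cuda.synchronize()" callback:

    import torch
    from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor

    buf = torch.empty(2, 3)        # in the real code: the receive buffer registered with an irecv

    def fake_wait():               # stand-in for the wait callback built in _communicate
        return torch.ones(2, 3)    # a gather-style waitfunc may return a rebuilt tensor

    fut = FutureTensor(buf, fake_wait)
    # ... other work can overlap here while the communication completes ...
    print(fut.get())               # first call runs the waitfunc and caches the result
    print(fut.get())               # later calls return the cached tensor without waiting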
@@ -6,6 +6,7 @@ from torch.autograd.variable import Variable
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
+from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
@@ -19,7 +20,7 @@ from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)

-Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
+Batch = Union[torch.Tensor, FutureTensor, List[Union[torch.Tensor, FutureTensor]], Tuple[Union[torch.Tensor, FutureTensor], ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[
    [Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
@@ -288,6 +289,8 @@ def forward_step(
    if unwrap_output_tensor:
        input_tensor = [input_tensor]

+    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
+
    unwrapped_model.set_input_tensor(input_tensor)
    with torch.cuda.amp.autocast(
        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
@@ -349,15 +352,23 @@ def backward_step(
    unwrap_input_tensor_grad = not isinstance(input_tensor, list)
    if unwrap_input_tensor_grad:
        input_tensor = [input_tensor]
+    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]

    for x in input_tensor:
        if x is not None:
            x.retain_grad()

    if not isinstance(output_tensor, list):
        output_tensor = [output_tensor]
+    output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor]

    if not isinstance(output_tensor_grad, list):
        output_tensor_grad = [output_tensor_grad]
+    output_tensor_grad = [ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad]

    # Backward pass.
    if grad_scaler is not None and output_tensor_grad[0] is None:
        output_tensor[0] = grad_scaler.scale(output_tensor[0])
......
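On the consumer side, forward_step and backward_step normalize their inputs with the list comprehensions added above, so downstream model code never sees a FutureTensor. A minimal, self-contained illustration of that idiom (the tensors and the lambda wait callback here are stand-ins for illustration only):

    import torch
    from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor

    ready = torch.zeros(4)
    pending = FutureTensor(torch.empty(4), lambda: torch.arange(4.0))

    input_tensor = [ready, pending, None]
    # Same idiom as forward_step/backward_step: resolve futures, pass everything else through.
    input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
    print(input_tensor)  # [zeros tensor, tensor([0., 1., 2., 3.]), None]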
@@ -6,6 +6,7 @@ import torch
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
+from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
@@ -65,13 +66,14 @@ def recv_forward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
-) -> List[Union[None, torch.Tensor]]:
+    async_comm: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
    input_tensors = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            input_tensors.append(None)
        else:
-            input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
+            input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
    return input_tensors
@@ -79,13 +81,14 @@ def recv_backward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
-) -> List[Union[None, torch.Tensor]]:
+    async_comm: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
    output_tensor_grads = []
    for tensor_shape in tensor_shapes:
        if tensor_shape is None:
            output_tensor_grads.append(None)
        else:
-            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
+            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
    return output_tensor_grads
@@ -94,13 +97,14 @@ def send_forward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
) -> None:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
        if tensor_shape is None:
            continue
-        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
+        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
def send_backward(
@@ -108,13 +112,14 @@ def send_backward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
+    async_comm: bool = False,
) -> None:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
        if tensor_shape is None:
            continue
-        p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
+        p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
def send_forward_recv_backward(
@@ -122,7 +127,8 @@ def send_forward_recv_backward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
-) -> List[Union[None, torch.Tensor]]:
+    async_comm: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
    if not isinstance(output_tensors, list):
        output_tensors = [output_tensors]
    output_tensor_grads = []
@@ -130,7 +136,7 @@ def send_forward_recv_backward(
        if tensor_shape is None:
            output_tensor_grads.append(None)
            continue
-        output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
+        output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
        output_tensor_grads.append(output_tensor_grad)
    return output_tensor_grads
@@ -140,7 +146,8 @@ def send_backward_recv_forward(
    tensor_shapes: List[Union[None, List[int]]],
    *,
    dtype: Optional[torch.dtype] = None,
-) -> List[Union[None, torch.Tensor]]:
+    async_comm: bool = False,
+) -> List[Union[None, torch.Tensor, FutureTensor]]:
    if not isinstance(input_tensor_grads, list):
        input_tensor_grads = [input_tensor_grads]
    input_tensors = []
@@ -148,7 +155,7 @@ def send_backward_recv_forward(
        if tensor_shape is None:
            input_tensors.append(None)
            continue
-        input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
+        input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
        input_tensors.append(input_tensor)
    return input_tensors
@@ -165,7 +172,8 @@ def forward_backward_pipelining_without_interleaving(
    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
    disable_autocast: bool = False,
    deallocate_pipeline_outputs: bool = False,
-    **kwawrgs,
+    async_comm: bool = False,
+    **kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
    """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -243,7 +251,7 @@ def forward_backward_pipelining_without_interleaving(
    for i in range(num_warmup_microbatches):
        _logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
        _logger.debug("receive fwd")
-        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
+        input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
        cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
        output_tensor = forward_step(
            forward_step_func,
@@ -255,7 +263,7 @@ def forward_backward_pipelining_without_interleaving(
            disable_autocast,
        )
        _logger.debug("send fwd")
-        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
+        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)

        if not forward_only:
            input_tensors.append(input_tensor)
@@ -267,7 +275,7 @@ def forward_backward_pipelining_without_interleaving(
    # receive this tensor here.
    if num_microbatches_remaining > 0:
        _logger.debug("recv_forward before steady state start")
-        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
+        input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
    ###################################################################################################################
    # Run 1F1B in steady state.
@@ -289,15 +297,15 @@ def forward_backward_pipelining_without_interleaving(
        )
        if forward_only:
            _logger.debug("send fwd")
-            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
+            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)

            if not last_iteration:
                _logger.debug("receive fwd (last iteration)")
-                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
+                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
        else:
            _logger.debug("send fwd & receive bwd")
-            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
+            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)

            # Add input_tensor and output_tensor to end of list.
            input_tensors.append(input_tensor)
@@ -320,10 +328,10 @@ def forward_backward_pipelining_without_interleaving(
            if last_iteration:
                input_tensor = None
                _logger.debug("send bwd")
-                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
+                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
            else:
                _logger.debug("send bwd and receive fwd")
-                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
+                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
    ###################################################################################################################
    # Run cooldown backward passes.
    ###################################################################################################################
@@ -335,7 +343,7 @@ def forward_backward_pipelining_without_interleaving(
            output_tensor = output_tensors.pop(0)

            _logger.debug("receive bwd")
-            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
+            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)

            input_tensor_grad = backward_step(
                input_tensor,
@@ -347,6 +355,6 @@ def forward_backward_pipelining_without_interleaving(
            )

            _logger.debug("send bwd")
-            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
+            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)

    return losses_reduced
@@ -116,7 +116,7 @@ def fwd_step_func(batch, model):
        lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
        averaged_loss = average_losses_across_data_parallel_group([lm_loss])
        if data_idx >= 1536:
-            assert lm_loss < 4.8
+            assert averaged_loss < 4.8
            if not ONCE:
                print("LOSS OK")
                ONCE = True
@@ -126,7 +126,7 @@ def fwd_step_func(batch, model):


def train(
-    model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
+    model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm
):
    sequence_len = global_vars.get_args().seq_length
    micro_batch_size = global_vars.get_args().micro_batch_size
@@ -139,7 +139,7 @@ def train(
        batch = generate_fancy_data_labels(sequence_len, batch_size)
        optim.zero_grad()
        forward_backward_func(
-            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
+            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm,
        )
        optim.step()
@@ -157,7 +157,14 @@ if __name__ == "__main__":
    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    failure = None
+    init = True
    try:
+        for virtual_pipeline_model_parallel_size in (2, None):
+            async_comm = virtual_pipeline_model_parallel_size is None
+            data_idx = 0
+            ONCE = False
+            if init:
+                init = False
                args = global_vars.get_args()
                args.padded_vocab_size = 128  # needed in standalone gpt
                batch_size = args.global_batch_size
@@ -169,8 +176,8 @@ if __name__ == "__main__":
                    args.micro_batch_size,
                    args.data_parallel_size,
                )
-        virtual_pipeline_model_parallel_size = 2
-        pipeline_model_parallel_size = world_size
+            else:
+                parallel_state.destroy_model_parallel()
            parallel_state.initialize_model_parallel(
                args.tensor_model_parallel_size,
                args.pipeline_model_parallel_size,
@@ -179,6 +186,7 @@ if __name__ == "__main__":
            pipeline_model_parallel_size = (
                parallel_state.get_pipeline_model_parallel_world_size()
            )
            tensor_parallel.random.model_parallel_cuda_manual_seed(0)
            model = build_model(
                bert_model_provider,
@@ -201,6 +209,7 @@ if __name__ == "__main__":
                optim,
                virtual_pipeline_model_parallel_size,
                args.pipeline_model_parallel_size,
+                async_comm,
            )
    except Exception as e:
        failure = str(e)
......
@@ -92,7 +92,6 @@ def loss_func(loss_mask, output_tensor):
    # Reduce loss for logging.
    averaged_loss = average_losses_across_data_parallel_group([loss])
    return loss, {"lm loss": averaged_loss[0]}
@@ -104,7 +103,7 @@ def fwd_step_func(batch, model):
    return output_tensor, partial(loss_func, loss_mask)


-def train(model, optim, pipeline_model_parallel_size):
+def train(model, optim, pipeline_model_parallel_size, async_comm):
    sequence_len = global_vars.get_args().seq_length
    micro_batch_size = global_vars.get_args().micro_batch_size
    hidden_size = global_vars.get_args().hidden_size
@@ -125,7 +124,7 @@ def train(model, optim, pipeline_model_parallel_size):
        print("finished making batch...")
        optim.zero_grad()
        fwd_bwd_func(
-            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
+            fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
        )
        if torch.distributed.get_rank() == 0:
            print("finished forward step")
@@ -137,13 +136,16 @@ if __name__ == "__main__":

if __name__ == "__main__":
+    init = True
+    for async_comm in (False, True):
        global fancy_data
        global effective_length
+        if init:
+            init = False
            global_vars.set_global_variables()
-    fancy_data = download_fancy_data()
            args = global_vars.get_args()
+            fancy_data = download_fancy_data()
            effective_length = fancy_data.size(0) // args.seq_length
            effective_length = fancy_data.size(0) - args.seq_length
@@ -162,6 +164,9 @@ if __name__ == "__main__":
                args.data_parallel_size,  # args.data_parallel_size,
            )
        world_size = torch.distributed.get_world_size()
+        print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
        parallel_state.initialize_model_parallel(
            tensor_model_parallel_size_=args.tensor_model_parallel_size,
            pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
@@ -180,7 +185,7 @@ if __name__ == "__main__":
        assert isinstance(model, list), model
        _param_groups = _get_params_for_weight_decay_optimization(model)
        optim = torch.optim.Adam(_param_groups)
-        runtime = train(model, optim, args.pipeline_model_parallel_size)
+        runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)
        parallel_state.destroy_model_parallel()
        torch.distributed.barrier()
......
@@ -70,6 +70,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
        fwd_bwd_func: FwdStepFunc,
        pipeline_model_parallel_world_size: Optional[int],
        virtual_pipeline_model_parallel_size: Optional[int],
+        async_comm: bool = False,
    ) -> None:
        for dtype, deallocate_pipeline_outputs in itertools.product(
            [torch.float32] + _get_autocast_dtypes(), (True, False),
@@ -136,6 +137,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
                    PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
                ),
                dtype=dtype,
+                async_comm=async_comm,
                grad_scaler=grad_scaler,
                deallocate_pipeline_output=deallocate_pipeline_outputs,
            )
@@ -164,16 +166,26 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
    def test_no_pipelining_inference(self):
        self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)

-    def test_pipelining(self):
+    def test_pipelining_default(self):
        self._forward_backward_test_impl(
            False, forward_backward_pipelining_without_interleaving, None, None
        )

+    def test_pipelining_async(self):
+        self._forward_backward_test_impl(
+            False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
+        )
+
    def test_pipelining_inference(self):
        self._forward_backward_test_impl(
            True, forward_backward_pipelining_without_interleaving, None, None
        )

+    def test_pipelining_inference_async(self):
+        self._forward_backward_test_impl(
+            True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
+        )
+
    def test_pipelining_with_interleaving(self):
        self._forward_backward_test_impl(
            False, _forward_backward_pipelining_with_interleaving, None, None
......