Unverified commit 3fe35211 authored by eqy, committed by GitHub

Async pipeline parallel (#1373)

* initial check in

* fix

* fix test

* address some review comments and cleanup

* fix

* bookmark

* fix sync placement to come before gather

* similar fix for non-gather case

* add async bert

* update gpt minimal test

* allow selection of default pp test

* fix bert test

* cleanup

* cleanup
parent 68440264
......@@ -26,11 +26,27 @@ from apex.transformer.pipeline_parallel.utils import Shape
from apex.transformer.pipeline_parallel._timers import _Timers
class FutureTensor:
def __init__(self, tensor: torch.Tensor, waitfunc):
self.tensor = tensor
self.waitfunc = waitfunc
def get(self):
if self.waitfunc is not None:
res = self.waitfunc()
if isinstance(res, torch.Tensor):
self.tensor = res
self.waitfunc = None
return self.tensor
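For context: a `FutureTensor` pairs a receive buffer with a wait callback, so the blocking wait is deferred until the value is first consumed. `get()` is idempotent because `waitfunc` is cleared after the first call, and if the callback returns a tensor (as the gather callbacks further down do), that tensor becomes the result. A minimal usage sketch, assuming an initialized process group and a hypothetical single `irecv` standing in for the batched requests used in this file:

import torch
import torch.distributed as dist

buf = torch.empty(4, device="cuda")   # receive buffer
req = dist.irecv(buf, src=0)          # hypothetical single request; the diff uses batch_isend_irecv

def wait_and_sync():
    req.wait()
    torch.cuda.synchronize()          # mirrors the sync folded into the wait callbacks below

fut = FutureTensor(buf, wait_and_sync)
# ... overlap independent compute here ...
t = fut.get()                         # first call blocks; later calls return the cached tensor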
def _run_p2pops(
tensor_send_prev: Union[torch.Tensor, None],
tensor_send_next: Union[torch.Tensor, None],
tensor_recv_prev: Union[torch.Tensor, None],
tensor_recv_next: Union[torch.Tensor, None],
async_comm: bool = False
):
ops = []
if tensor_send_prev is not None:
......@@ -63,8 +79,18 @@ def _run_p2pops(
ops.append(recv_next_op)
if len(ops) > 0:
reqs = torch.distributed.batch_isend_irecv(ops)
for req in reqs:
req.wait()
if async_comm:
assert len(reqs) == len(ops)
tensor_send_prev_req = None if tensor_send_prev is None else reqs.pop(0)
tensor_recv_prev_req = None if tensor_recv_prev is None else reqs.pop(0)
tensor_send_next_req = None if tensor_send_next is None else reqs.pop(0)
tensor_recv_next_req = None if tensor_recv_next is None else reqs.pop(0)
return (tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req)
else:
for req in reqs:
req.wait()
return (None, None, None, None)
return (None, None, None, None)
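An observation on the return path above: the `pop(0)` mapping works because `batch_isend_irecv` returns one request per op in submission order, and ops are appended in the fixed order send_prev, recv_prev, send_next, recv_next. A hedged sketch of how a caller consumes the returned tuple, assuming the four buffers are set up as in `_communicate` below:

send_prev_req, recv_prev_req, send_next_req, recv_next_req = _run_p2pops(
    tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=True
)
# Each slot is None when the matching tensor was None and no op was issued for it.
if recv_prev_req is not None:
    recv_prev_req.wait()  # requests are independent and can be waited on in any order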
def _communicate(
......@@ -79,7 +105,8 @@ def _communicate(
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
fp32_residual_connection: bool = False,
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor, None], Union[torch.Tensor, FutureTensor, None]]:
"""Base function for communication of tensors between stages.
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
......@@ -161,26 +188,73 @@ def _communicate(
tensor_send_prev = split_tensor_into_1d_equal_chunks(tensor_send_prev)
# Send tensors in both the forward and backward directions as appropriate.
_run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next)
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
tensor_send_prev_req, tensor_recv_prev_req, tensor_send_next_req, tensor_recv_next_req = _run_p2pops(tensor_send_prev, tensor_send_next, tensor_recv_prev, tensor_recv_next, async_comm=async_comm)
if async_comm:
tensor_recv_prev_waitfunc = None
tensor_recv_next_waitfunc = None
# TODO: investigate whether this is necessary for correctness (ref: https://github.com/pytorch/pytorch/issues/38642)
# see also: sync added for async_comm callbacks below in gather_recv_prev_wait and gather_recv_next_wait
if tensor_recv_prev_req is not None:
def tensor_recv_prev_wait():
tensor_recv_prev_req.wait()
torch.cuda.synchronize()
tensor_recv_prev_waitfunc = tensor_recv_prev_wait
if tensor_recv_next_req is not None:
def tensor_recv_next_wait():
tensor_recv_next_req.wait()
torch.cuda.synchronize()
tensor_recv_next_waitfunc = tensor_recv_next_wait
else:
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
# If using scatter-gather optimization, gather smaller chunks.
if not override_scatter_gather_tensors_in_pipeline and scatter_gather_tensors_in_pipeline:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
if not async_comm:
if recv_prev:
tensor_recv_prev = (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
if recv_next:
tensor_recv_next = (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
else:
def gather_recv_prev_wait():
tensor_recv_prev_req.wait()
# From @Deepak's PR https://github.com/NVIDIA/Megatron-LM/commit/27fc468964064eeb33b703c9a0b2af938d80dd14
# A sync seems to be needed before the gather; otherwise losses jump around, e.g., in run_gpt_minimal_test
torch.cuda.synchronize()
return (
gather_split_1d_tensor(tensor_recv_prev)
.view(tensor_shape)
.requires_grad_()
)
def gather_recv_next_wait():
tensor_recv_next_req.wait()
torch.cuda.synchronize()
return (
gather_split_1d_tensor(tensor_recv_next)
.view(tensor_shape)
.requires_grad_()
)
tensor_recv_prev_waitfunc = gather_recv_prev_wait
tensor_recv_next_waitfunc = gather_recv_next_wait
if async_comm:
future_tensor_recv_prev = None
future_tensor_recv_next = None
if tensor_recv_prev is not None:
future_tensor_recv_prev = FutureTensor(tensor_recv_prev, tensor_recv_prev_waitfunc)
if tensor_recv_next is not None:
future_tensor_recv_next = FutureTensor(tensor_recv_next, tensor_recv_next_waitfunc)
return future_tensor_recv_prev, future_tensor_recv_next
return tensor_recv_prev, tensor_recv_next
......@@ -190,7 +264,8 @@ def recv_forward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from previous rank in pipeline (forward receive)."""
if parallel_state.is_pipeline_first_stage():
return None
......@@ -204,6 +279,7 @@ def recv_forward(
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("forward-recv").stop()
......@@ -215,7 +291,8 @@ def recv_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Receive tensor from next rank in pipeline (backward receive)."""
if parallel_state.is_pipeline_last_stage():
return None
......@@ -228,6 +305,7 @@ def recv_backward(
recv_next=True,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("backward-recv").stop()
......@@ -241,6 +319,7 @@ def send_forward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
) -> None:
"""Send tensor to next rank in pipeline (forward send)."""
if parallel_state.is_pipeline_last_stage():
......@@ -255,6 +334,7 @@ def send_forward(
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("forward-send").stop()
......@@ -266,6 +346,8 @@ def send_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
async_comm: bool = False,
) -> None:
"""Send tensor to previous rank in pipeline (backward send)."""
if parallel_state.is_pipeline_first_stage():
......@@ -279,6 +361,7 @@ def send_backward(
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("backward-send").stop()
......@@ -290,7 +373,8 @@ def send_forward_recv_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> Union[None, torch.Tensor]:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with next rank in pipeline."""
if parallel_state.is_pipeline_last_stage():
return None
......@@ -303,6 +387,7 @@ def send_forward_recv_backward(
recv_next=True,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
......@@ -315,7 +400,8 @@ def send_backward_recv_forward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> Union[None, torch.Tensor]:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor, None]:
"""Batched send and recv with previous rank in pipeline."""
if parallel_state.is_pipeline_first_stage():
return None
......@@ -328,6 +414,7 @@ def send_backward_recv_forward(
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
......@@ -341,7 +428,8 @@ def send_forward_recv_forward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from previous rank and send to next rank in pipeline."""
# if timers is not None:
# timers("forward-send-forward-recv").start()
......@@ -352,6 +440,7 @@ def send_forward_recv_forward(
recv_next=False,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
......@@ -365,7 +454,8 @@ def send_backward_recv_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
async_comm: bool = False,
) -> Union[torch.Tensor, FutureTensor]:
"""Batched recv from next rank and send to previous rank in pipeline."""
# if timers is not None:
# timers("backward-send-backward-recv").start()
......@@ -376,6 +466,7 @@ def send_backward_recv_backward(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
......@@ -391,7 +482,8 @@ def send_forward_backward_recv_forward_backward(
*,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
async_comm: bool = False,
) -> Tuple[Union[torch.Tensor, FutureTensor], Union[torch.Tensor, FutureTensor]]:
"""Batched send and recv with previous and next ranks in pipeline."""
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").start()
......@@ -402,6 +494,7 @@ def send_forward_backward_recv_forward_backward(
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=dtype,
async_comm=async_comm,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
......
......@@ -6,6 +6,7 @@ from torch.autograd.variable import Variable
from apex.normalization.fused_layer_norm import FusedLayerNorm
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
......@@ -19,7 +20,7 @@ from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
Batch = Union[torch.Tensor, FutureTensor, List[Union[torch.Tensor, FutureTensor]], Tuple[Union[torch.Tensor, FutureTensor], ...]]
LossFunc = Callable[[torch.Tensor], torch.Tensor]
FwdStepFunc = Callable[
[Optional[Batch], torch.nn.Module], Tuple[torch.Tensor, LossFunc]
......@@ -288,6 +289,8 @@ def forward_step(
if unwrap_output_tensor:
input_tensor = [input_tensor]
input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
unwrapped_model.set_input_tensor(input_tensor)
with torch.cuda.amp.autocast(
enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
......@@ -349,15 +352,23 @@ def backward_step(
unwrap_input_tensor_grad = not isinstance(input_tensor, list)
if unwrap_input_tensor_grad:
input_tensor = [input_tensor]
input_tensor = [inp.get() if isinstance(inp, FutureTensor) else inp for inp in input_tensor]
for x in input_tensor:
if x is not None:
x.retain_grad()
if not isinstance(output_tensor, list):
output_tensor = [output_tensor]
output_tensor = [out.get() if isinstance(out, FutureTensor) else out for out in output_tensor]
if not isinstance(output_tensor_grad, list):
output_tensor_grad = [output_tensor_grad]
output_tensor_grad = [ogr.get() if isinstance(ogr, FutureTensor) else ogr for ogr in output_tensor_grad]
# Backward pass.
if grad_scaler is not None and output_tensor_grad[0] is None:
output_tensor[0] = grad_scaler.scale(output_tensor[0])
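The resolve-before-use pattern above matters because autograd needs materialized tensors: `retain_grad()` and `backward()` cannot operate on a `FutureTensor`. A minimal sketch of the pattern, with `resolve` as a hypothetical helper name:

def resolve(xs):
    # FutureTensor entries block (at most once) and are replaced by their tensors;
    # plain tensors and None entries pass through unchanged.
    return [x.get() if isinstance(x, FutureTensor) else x for x in xs]

input_tensor = resolve(input_tensor)
output_tensor = resolve(output_tensor)
output_tensor_grad = resolve(output_tensor_grad)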
......
......@@ -6,6 +6,7 @@ import torch
from apex.transformer import parallel_state
from apex.transformer.enums import ModelType
from apex.transformer.pipeline_parallel import p2p_communication
from apex.transformer.pipeline_parallel.p2p_communication import FutureTensor
from apex.transformer.pipeline_parallel.utils import get_kth_microbatch
from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import get_num_microbatches
......@@ -65,13 +66,14 @@ def recv_forward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
async_comm: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
return input_tensors
......@@ -79,13 +81,14 @@ def recv_backward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
async_comm: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm))
return output_tensor_grads
......@@ -94,13 +97,14 @@ def send_forward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
def send_backward(
......@@ -108,13 +112,14 @@ def send_backward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
async_comm: bool = False,
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
def send_forward_recv_backward(
......@@ -122,7 +127,8 @@ def send_forward_recv_backward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
async_comm: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
output_tensor_grads = []
......@@ -130,7 +136,7 @@ def send_forward_recv_backward(
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
......@@ -140,7 +146,8 @@ def send_backward_recv_forward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
async_comm: bool = False,
) -> List[Union[None, torch.Tensor, FutureTensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
input_tensors = []
......@@ -148,7 +155,7 @@ def send_backward_recv_forward(
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype, async_comm=async_comm)
input_tensors.append(input_tensor)
return input_tensors
......@@ -165,7 +172,8 @@ def forward_backward_pipelining_without_interleaving(
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
deallocate_pipeline_outputs: bool = False,
**kwawrgs,
async_comm: bool = False,
**kwargs,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
......@@ -243,7 +251,7 @@ def forward_backward_pipelining_without_interleaving(
for i in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {i} / {num_warmup_microbatches}")
_logger.debug("receive fwd")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
cur_microbatch: Optional[torch.Tensor] = get_kth_microbatch(batch, i)
output_tensor = forward_step(
forward_step_func,
......@@ -255,7 +263,7 @@ def forward_backward_pipelining_without_interleaving(
disable_autocast,
)
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
if not forward_only:
input_tensors.append(input_tensor)
......@@ -267,7 +275,7 @@ def forward_backward_pipelining_without_interleaving(
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
input_tensor: List[Union[None, torch.Tensor, FutureTensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
###################################################################################################################
# Run 1F1B in steady state.
......@@ -289,15 +297,15 @@ def forward_backward_pipelining_without_interleaving(
)
if forward_only:
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
......@@ -320,10 +328,10 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
......@@ -335,7 +343,7 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype, async_comm=async_comm)
input_tensor_grad = backward_step(
input_tensor,
......@@ -347,6 +355,6 @@ def forward_backward_pipelining_without_interleaving(
)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype, async_comm=async_comm)
return losses_reduced
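A hedged invocation sketch (mirroring the tests below; model, optimizer, and batch setup are assumed): enabling async communication for the non-interleaved 1F1B schedule is a single extra keyword, since every send/recv helper above threads `async_comm` through to `_communicate`:

losses = forward_backward_pipelining_without_interleaving(
    fwd_step_func,   # FwdStepFunc: (batch, model) -> (output_tensor, loss_func)
    batch,
    model,
    forward_only=False,
    tensor_shape=(seq_length, micro_batch_size, hidden_size),  # illustrative
    dtype=torch.float16,
    async_comm=True,
)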
......@@ -116,7 +116,7 @@ def fwd_step_func(batch, model):
lm_loss = torch.sum(lm_loss_.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
averaged_loss = average_losses_across_data_parallel_group([lm_loss])
if data_idx >= 1536:
assert lm_loss < 4.8
assert averaged_loss < 4.8
if not ONCE:
print("LOSS OK")
ONCE = True
......@@ -126,7 +126,7 @@ def fwd_step_func(batch, model):
def train(
model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size
model, optim, virtual_pipeline_model_parallel_size, pipeline_model_parallel_size, async_comm
):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
......@@ -139,7 +139,7 @@ def train(
batch = generate_fancy_data_labels(sequence_len, batch_size)
optim.zero_grad()
forward_backward_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm,
)
optim.step()
......@@ -157,51 +157,60 @@ if __name__ == "__main__":
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
init = True
try:
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
virtual_pipeline_model_parallel_size = 2
pipeline_model_parallel_size = world_size
parallel_state.initialize_model_parallel(
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list)
assert len(model) == (
1
if virtual_pipeline_model_parallel_size is None
else virtual_pipeline_model_parallel_size
)
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
print(effective_length)
print(fancy_data.size(0))
train(
model,
optim,
virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_size,
)
for virtual_pipeline_model_parallel_size in (2, None):
async_comm = virtual_pipeline_model_parallel_size is None
data_idx = 0
ONCE = False
if init:
init = False
args = global_vars.get_args()
args.padded_vocab_size = 128 # needed in standalone gpt
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
else:
parallel_state.destroy_model_parallel()
parallel_state.initialize_model_parallel(
args.tensor_model_parallel_size,
args.pipeline_model_parallel_size,
virtual_pipeline_model_parallel_size,
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
tensor_parallel.random.model_parallel_cuda_manual_seed(0)
model = build_model(
bert_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=virtual_pipeline_model_parallel_size,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list)
assert len(model) == (
1
if virtual_pipeline_model_parallel_size is None
else virtual_pipeline_model_parallel_size
)
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
print(effective_length)
print(fancy_data.size(0))
train(
model,
optim,
virtual_pipeline_model_parallel_size,
args.pipeline_model_parallel_size,
async_comm,
)
except Exception as e:
failure = str(e)
finally:
......
......@@ -92,7 +92,6 @@ def loss_func(loss_mask, output_tensor):
# Reduce loss for logging.
averaged_loss = average_losses_across_data_parallel_group([loss])
return loss, {"lm loss": averaged_loss[0]}
......@@ -104,7 +103,7 @@ def fwd_step_func(batch, model):
return output_tensor, partial(loss_func, loss_mask)
def train(model, optim, pipeline_model_parallel_size):
def train(model, optim, pipeline_model_parallel_size, async_comm):
sequence_len = global_vars.get_args().seq_length
micro_batch_size = global_vars.get_args().micro_batch_size
hidden_size = global_vars.get_args().hidden_size
......@@ -125,7 +124,7 @@ def train(model, optim, pipeline_model_parallel_size):
print("finished making batch...")
optim.zero_grad()
fwd_bwd_func(
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape
fwd_step_func, batch, model, forward_only=False, tensor_shape=tensor_shape, async_comm=async_comm
)
if torch.distributed.get_rank() == 0:
print("finished forward step")
......@@ -137,52 +136,58 @@ def train(model, optim, pipeline_model_parallel_size):
if __name__ == "__main__":
global fancy_data
global effective_length
global_vars.set_global_variables()
fancy_data = download_fancy_data()
args = global_vars.get_args()
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
args.padded_vocab_size = 128
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
world_size = torch.distributed.get_world_size()
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
)
init = True
for async_comm in (False, True):
global fancy_data
global effective_length
if init:
init = False
global_vars.set_global_variables()
args = global_vars.get_args()
fancy_data = download_fancy_data()
effective_length = fancy_data.size(0) // args.seq_length
effective_length = fancy_data.size(0) - args.seq_length
initialize_distributed()
world_size = torch.distributed.get_world_size()
failure = None
args.padded_vocab_size = 128
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
setup_microbatch_calculator(
args.rank,
args.rampup_batch_size,
args.global_batch_size,
args.micro_batch_size,
args.data_parallel_size,
)
world_size = torch.distributed.get_world_size()
print(args.tensor_model_parallel_size, "MODEL PARALLEL SIZE")
parallel_state.initialize_model_parallel(
tensor_model_parallel_size_=args.tensor_model_parallel_size,
pipeline_model_parallel_size_=args.pipeline_model_parallel_size,
)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
runtime = train(model, optim, args.pipeline_model_parallel_size)
pipeline_model_parallel_size = (
parallel_state.get_pipeline_model_parallel_world_size()
)
model_parallel_cuda_manual_seed(0)
model = build_model(
gpt_model_provider,
wrap_with_ddp=True,
virtual_pipeline_model_parallel_size=None,
cpu_offload=args.cpu_offload,
)
assert isinstance(model, list), model
_param_groups = _get_params_for_weight_decay_optimization(model)
optim = torch.optim.Adam(_param_groups)
runtime = train(model, optim, args.pipeline_model_parallel_size, async_comm)
parallel_state.destroy_model_parallel()
parallel_state.destroy_model_parallel()
torch.distributed.barrier()
if torch.distributed.get_rank() == 0:
print(TEST_SUCCESS_MESSAGE)
......
......@@ -70,6 +70,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
fwd_bwd_func: FwdStepFunc,
pipeline_model_parallel_world_size: Optional[int],
virtual_pipeline_model_parallel_size: Optional[int],
async_comm: bool = False,
) -> None:
for dtype, deallocate_pipeline_outputs in itertools.product(
[torch.float32] + _get_autocast_dtypes(), (True, False),
......@@ -136,6 +137,7 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
PipelineParallelForwardBackwardTest.HIDDEN_SIZE,
),
dtype=dtype,
async_comm=async_comm,
grad_scaler=grad_scaler,
deallocate_pipeline_outputs=deallocate_pipeline_outputs,
)
......@@ -164,16 +166,26 @@ class PipelineParallelForwardBackwardTest(DistributedTestBase):
def test_no_pipelining_inference(self):
self._forward_backward_test_impl(True, forward_backward_no_pipelining, 1, None)
def test_pipelining(self):
def test_pipelining_default(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_async(self):
self._forward_backward_test_impl(
False, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
def test_pipelining_inference(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None
)
def test_pipelining_inference_async(self):
self._forward_backward_test_impl(
True, forward_backward_pipelining_without_interleaving, None, None, async_comm=True
)
def test_pipelining_with_interleaving(self):
self._forward_backward_test_impl(
False, _forward_backward_pipelining_with_interleaving, None, None
......