Unverified Commit 79c01877 authored by Masaki Kozuki, committed by GitHub

transformer: Allows for custom sync context in no pipelining forward backward function (#1281)

* add kwarg of `custom_sync_context_handler`

* add `**kwargs` so that a `custom_sync_context_handler` mistakenly passed to the pipelining fwd/bwd funcs is ignored
parent 0da60e10
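
For readers of this change: the new `custom_sync_context_handler` keyword replaces the scheduler's default gradient-synchronization context (`model.no_sync` when the model is wrapped in `DistributedDataParallel`). Below is a minimal sketch of the kind of object it appears to expect, assuming, as the diff suggests, a zero-argument callable that returns a context manager; the `log_sync` handler and its messages are purely illustrative and not part of apex.

```python
import contextlib


# Any zero-argument callable returning a context manager should qualify; this
# mirrors how `model.no_sync` is used. `log_sync` below is purely illustrative.
@contextlib.contextmanager
def log_sync():
    print("entering custom sync region")
    try:
        yield
    finally:
        print("leaving custom sync region")


# The scheduler would invoke it the same way it invokes `model.no_sync`:
with log_sync():
    pass  # forward/backward over micro-batches would run here
```
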
@@ -37,6 +37,7 @@ def forward_backward_no_pipelining(
     dtype: Optional[torch.dtype] = None,
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
+    custom_sync_context_handler=None,
     **kwargs,
 ):
     """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
@@ -55,7 +56,10 @@
         forward_only:
         grad_scaler:
         dtype:
-        disable_autocast
+        disable_autocast: Turn off the `enabled` flag of `torch.cuda.amp.autocast` if :obj:`True`.
+            Should be used when your forward and loss computation already runs inside an autocast
+            context, to avoid unnecessarily nesting autocast contexts.
+        custom_sync_context_handler: If provided, used instead of `model.no_sync` to control gradient synchronization.
         **kwargs: Added to handle `tensor_shape` which has no effect on this function.
     Returns:
@@ -68,9 +72,12 @@
     model = model[0]
     model_type = get_model_type(model)
-    context_handler = placeholder_handler
-    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
+    if custom_sync_context_handler is not None:
+        context_handler = custom_sync_context_handler
+    elif isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
         context_handler = model.no_sync
+    else:
+        context_handler = placeholder_handler
     losses_reduced = []
     input_tensor, output_tensor_grad = None, None
......
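
To make the precedence introduced above explicit, here is a self-contained restatement of the selection logic: an explicit handler wins, a `DistributedDataParallel` model falls back to `no_sync`, and everything else gets a no-op. `pick_context_handler` and `my_handler` are hypothetical names written for illustration, and `contextlib.nullcontext` stands in for apex's `placeholder_handler`.

```python
import contextlib

import torch


def pick_context_handler(model, custom_sync_context_handler=None):
    """Hypothetical restatement of the precedence added in this commit."""
    if custom_sync_context_handler is not None:
        return custom_sync_context_handler
    if isinstance(model, torch.nn.parallel.distributed.DistributedDataParallel):
        return model.no_sync
    return contextlib.nullcontext  # stand-in for apex's `placeholder_handler`


@contextlib.contextmanager
def my_handler():
    yield


plain = torch.nn.Linear(4, 4)
# A plain (non-DDP) module falls through to the no-op placeholder ...
assert pick_context_handler(plain) is contextlib.nullcontext
# ... while an explicitly supplied handler always takes precedence.
assert pick_context_handler(plain, my_handler) is my_handler
```
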
@@ -33,6 +33,7 @@ def _forward_backward_pipelining_with_interleaving(
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
     deallocate_pipeline_outputs: bool = False,
+    **kwargs,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -60,6 +61,8 @@
             torch.float32 will be used even if ``autocast`` is enabled.
         grad_scaler:
         disable_autocast:
+        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
+            each pipeline stage. Experimental.
     Returns:
         a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......
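
The second commit bullet explains why this scheduler (and the non-interleaved one below) now accepts `**kwargs`: a driver that forwards one keyword dict to whichever schedule it selects should not crash when `custom_sync_context_handler`, which only the no-pipelining path understands, tags along. A small sketch of that defensive pattern follows; both function names are purely illustrative and do not exist in apex.

```python
# Illustrative only: neither function name exists in apex.
def schedule_strict(*, forward_only):
    return "ran"


def schedule_tolerant(*, forward_only, **kwargs):
    # Unsupported keywords such as `custom_sync_context_handler` are absorbed
    # by **kwargs and simply ignored.
    return "ran"


common_kwargs = dict(forward_only=False, custom_sync_context_handler=None)

try:
    schedule_strict(**common_kwargs)
except TypeError as exc:
    print("rejected:", exc)

print(schedule_tolerant(**common_kwargs))
```
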
@@ -164,6 +164,7 @@ def forward_backward_pipelining_without_interleaving(
     grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
     disable_autocast: bool = False,
     deallocate_pipeline_outputs: bool = False,
+    **kwargs,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -185,6 +186,10 @@
         tensor_shape: Shape of tensor. Required for P2P communication.
         dtype: dtype used in p2p communication. If ``None`` (default value),
             torch.float32 will be used even if ``autocast`` is enabled.
+        grad_scaler:
+        disable_autocast:
+        deallocate_pipeline_outputs: If :obj:`True`, free the data of the output tensor of
+            each pipeline stage. Experimental.
     Returns:
         a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
......
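
The docstring additions describe `deallocate_pipeline_outputs` as freeing each pipeline stage's output tensor data. As background, one common way to achieve that in PyTorch, shown here only as a sketch of the general technique and not as apex's actual implementation, is to swap the tensor's `.data` for a tiny placeholder so the original storage can be reclaimed while the Python object stays usable as a graph handle.

```python
import torch


def free_output_data(out: torch.Tensor) -> None:
    # Sketch of the general technique only; apex's implementation may differ.
    # Re-pointing `.data` at a 1-element tensor drops the reference to the
    # original storage, letting it be freed once nothing else holds it.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)


x = torch.randn(1024, 1024, requires_grad=True)
loss = (x * 2).sum()        # small graph that does not involve `stage_output`
stage_output = x * 3        # pretend this is one pipeline stage's output
free_output_data(stage_output)
print(stage_output.shape)   # torch.Size([1]): bulk data released
loss.backward()             # unrelated backward still works
```
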