Unverified commit b88c507e authored by Masaki Kozuki, committed by GitHub

Add an argument of `dtype` to forward_backward functions to specify the dtype used in p2p comm (#1249)

* let users specify dtype for p2p comm, taking the possibility of O2-style AMP into account

* add `dtype` argument to forward_backward functions

* fix

* better message

* add docstring of dtype

* add a link to dtype logic of p2p comm
parent e8473822
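Not part of the commit, but for context: a minimal sketch (requires a CUDA device; plain `torch` only) of the problem the new `dtype` argument addresses. Under `torch.cuda.amp.autocast`, activations produced on one pipeline stage come out in FP16 even though the module parameters stay FP32, so the receiving stage has to be told which dtype to allocate its p2p buffer with.

```python
import torch

# Illustrative only: autocast changes the activation dtype, which is the dtype a
# downstream pipeline stage would need for its p2p recv buffer.
linear = torch.nn.Linear(8, 8).cuda()     # parameters stay torch.float32
x = torch.randn(4, 8, device="cuda")

with torch.cuda.amp.autocast():
    y = linear(x)

print(linear.weight.dtype)  # torch.float32
print(y.dtype)              # torch.float16 -- an FP32 recv buffer would not match
```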
@@ -16,11 +16,9 @@
 from functools import reduce
 import operator
 from typing import Union, Optional, Tuple
-import warnings
 import torch
-from apex._autocast_utils import _get_current_dtype
 from apex.transformer import parallel_state
 from apex.transformer.utils import split_tensor_into_1d_equal_chunks
 from apex.transformer.utils import gather_split_1d_tensor
@@ -76,7 +74,7 @@ def _communicate(
     recv_next: bool,
     tensor_shape: Optional[Shape] = None,
     override_scatter_gather_tensors_in_pipeline: bool = False,
-    dtype_: torch.dtype = torch.float,
+    dtype_: Optional[torch.dtype] = None,
     *,
     scatter_gather_tensors_in_pipeline: bool = True,
     params_dtype: Optional[torch.dtype] = None,
@@ -84,6 +82,12 @@ def _communicate(
 ) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
     """Base function for communication of tensors between stages.
+    dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
+    torch.float32 is used.
+    See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
+    for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.
     Args:
         tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
         tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
@@ -118,20 +122,20 @@
     else:
         tensor_chunk_shape = tensor_shape
-    # NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
+    # The dtype logic below is copied from NVIDIA/Megatron-LM repo:
+    # https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
+    # NOTE (mkozuki): Currently NeMo is implementing APEX AMP O2 style using PyTorch. In O2 style, forcing p2p comm to
+    # use FP32 will be a perf killer so that I decided to reanimate `dtype_` argument with the default value of `None`.
+    # NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
     # FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
     # It might be possible if we restrict model architecture.
-    # dtype = params_dtype or torch.float
-    # if fp32_residual_connection:
-    #     dtype = torch.float
-    # if dtype_ is not None:
-    #     dtype = dtype_
-    #     requires_grad = False
-    if dtype_ != torch.float32 or params_dtype is not None:
-        if torch.distributed.get_rank() == 0:
-            warnings.warn("Tensor P2P communications are executed in FP32")
-    dtype = torch.float32
+    dtype = params_dtype or torch.float
+    if fp32_residual_connection:
+        dtype = torch.float
     requires_grad = True
+    if dtype_ is not None:
+        dtype = dtype_
+        requires_grad = False
     if recv_prev:
         tensor_recv_prev = torch.empty(
@@ -199,7 +203,7 @@ def recv_forward(
         recv_next=False,
         tensor_shape=tensor_shape,
         override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("forward-recv").stop()
@@ -223,7 +227,7 @@ def recv_backward(
         recv_prev=False,
         recv_next=True,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("backward-recv").stop()
@@ -250,7 +254,7 @@ def send_forward(
         recv_next=False,
         override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
    )
     # if timers is not None:
     #     timers("forward-send").stop()
@@ -274,7 +278,7 @@ def send_backward(
         recv_prev=False,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("backward-send").stop()
@@ -298,7 +302,7 @@ def send_forward_recv_backward(
         recv_prev=False,
         recv_next=True,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("forward-send-backward-recv").stop()
@@ -323,7 +327,7 @@ def send_backward_recv_forward(
         recv_prev=True,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("backward-send-forward-recv").stop()
@@ -347,7 +351,7 @@ def send_forward_recv_forward(
         recv_prev=recv_prev,
         recv_next=False,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("forward-send-forward-recv").stop()
@@ -359,7 +363,7 @@ def send_backward_recv_backward(
     recv_next: bool,
     tensor_shape: Shape,
     *,
-    dtype: torch.dtype = torch.float,
+    dtype: Optional[torch.dtype] = None,
     timers: _Timers = None,
 ) -> torch.Tensor:
     """Batched recv from next rank and send to previous rank in pipeline."""
@@ -371,7 +375,7 @@ def send_backward_recv_backward(
         recv_prev=False,
         recv_next=recv_next,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("backward-send-backward-recv").stop()
@@ -397,7 +401,7 @@ def send_forward_backward_recv_forward_backward(
         recv_prev=recv_prev,
         recv_next=recv_next,
         tensor_shape=tensor_shape,
-        dtype_=_get_current_dtype(dtype),
+        dtype_=dtype,
     )
     # if timers is not None:
     #     timers("forward-backward-send-forward-backward-recv").stop()
......
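Not part of the diff: a standalone sketch of the dtype precedence that the rewritten block in `_communicate` implements, i.e. `params_dtype` (falling back to FP32) as the baseline, `fp32_residual_connection` forcing FP32, and an explicit `dtype_` overriding both. The helper name `_resolve_p2p_dtype` is hypothetical; only the precedence mirrors the code above.

```python
from typing import Optional

import torch


def _resolve_p2p_dtype(
    dtype_: Optional[torch.dtype],
    params_dtype: Optional[torch.dtype],
    fp32_residual_connection: bool,
) -> torch.dtype:
    # Mirrors the precedence in _communicate: params_dtype (or FP32) is the
    # baseline, fp32_residual_connection forces FP32, an explicit dtype_ wins.
    dtype = params_dtype or torch.float
    if fp32_residual_connection:
        dtype = torch.float
    if dtype_ is not None:
        dtype = dtype_
    return dtype


assert _resolve_p2p_dtype(None, None, False) is torch.float32
assert _resolve_p2p_dtype(None, torch.half, True) is torch.float32
assert _resolve_p2p_dtype(torch.bfloat16, torch.half, True) is torch.bfloat16
```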
@@ -28,6 +28,7 @@ def _forward_backward_pipelining_with_interleaving(
     *,
     forward_only: bool,
     tensor_shape: Optional[Union[List[int], torch.Size]] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -51,6 +52,8 @@ def _forward_backward_pipelining_with_interleaving(
     Keyword args:
         forward_only:
         tensor_shape: Shape of tensor.
+        dtype: dtype used in p2p communication. If ``None`` (default value),
+            torch.float32 will be used even if ``autocast`` is enabled.
     Returns:
         a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -163,7 +166,7 @@ def _forward_backward_pipelining_with_interleaving(
     # Run warmup forward passes.
     ###################################################################################################################
     parallel_state.set_virtual_pipeline_model_parallel_rank(0)
-    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
+    input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
     _logger.info("Warmup phase")
     for k in range(num_warmup_microbatches):
         _logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
@@ -201,12 +204,13 @@ def _forward_backward_pipelining_with_interleaving(
                 recv_prev=recv_prev,
                 recv_next=recv_next,
                 tensor_shape=tensor_shape,
+                dtype=dtype,
             )
             output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
         else:
             _logger.debug("send fwd and receive fwd")
             input_tensor = p2p_communication.send_forward_recv_forward(
-                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
+                output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
         input_tensors[next_forward_model_chunk_id].append(input_tensor)
     ###################################################################################################################
@@ -281,6 +285,7 @@ def _forward_backward_pipelining_with_interleaving(
             recv_prev=recv_prev,
             recv_next=recv_next,
             tensor_shape=tensor_shape,
+            dtype=dtype,
         )
         # Put input_tensor and output_tensor_grad in data structures in the
@@ -296,7 +301,7 @@ def _forward_backward_pipelining_with_interleaving(
     _logger.info("Cooldown phase")
     if not forward_only:
         if all_warmup_microbatches:
-            output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
+            output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
         for k in range(num_microbatches_remaining, num_microbatches):
             _logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
             input_tensor_grad = backward_step_helper(k)
@@ -309,7 +314,7 @@ def _forward_backward_pipelining_with_interleaving(
                 recv_next = False
             output_tensor_grads[next_backward_model_chunk_id].append(
                 p2p_communication.send_backward_recv_backward(
-                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
+                    input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype)
             )
     return losses_reduced
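Not part of the diff: for O2-style AMP (the NeMo case mentioned in the `_communicate` comments above), the parameters themselves are FP16, so one reasonable way for a caller to pick the new `dtype` argument is to read it off the model rather than hard-coding it. A tiny illustration with a stand-in module:

```python
import torch

# Stand-in for an O2-style module whose parameters have been cast to FP16.
model = torch.nn.Linear(8, 8).half()

# Derive the p2p dtype from the parameters instead of hard-coding it.
dtype = next(model.parameters()).dtype
print(dtype)  # torch.float16 -- pass this as `dtype=...` to the schedule functions
```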
@@ -59,51 +59,67 @@ def get_tensor_shapes(
     return tensor_shapes
-def recv_forward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
+def recv_forward(
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> List[Union[None, torch.Tensor]]:
     input_tensors = []
     for tensor_shape in tensor_shapes:
         if tensor_shape is None:
             input_tensors.append(None)
         else:
-            input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
+            input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
     return input_tensors
-def recv_backward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
+def recv_backward(
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
+) -> List[Union[None, torch.Tensor]]:
     output_tensor_grads = []
     for tensor_shape in tensor_shapes:
         if tensor_shape is None:
             output_tensor_grads.append(None)
         else:
-            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
+            output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
     return output_tensor_grads
 def send_forward(
-    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
+    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
 ) -> None:
     if not isinstance(output_tensors, list):
         output_tensors = [output_tensors]
     for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
         if tensor_shape is None:
             continue
-        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
+        p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
 def send_backward(
     input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
     tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
 ) -> None:
     if not isinstance(input_tensor_grads, list):
         input_tensor_grads = [input_tensor_grads]
     for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
         if tensor_shape is None:
             continue
-        p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
+        p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
 def send_forward_recv_backward(
-    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
+    output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
+    tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
 ) -> List[Union[None, torch.Tensor]]:
     if not isinstance(output_tensors, list):
         output_tensors = [output_tensors]
@@ -112,7 +128,7 @@ def send_forward_recv_backward(
         if tensor_shape is None:
             output_tensor_grads.append(None)
             continue
-        output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
+        output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
         output_tensor_grads.append(output_tensor_grad)
     return output_tensor_grads
@@ -120,6 +136,8 @@ def send_forward_recv_backward(
 def send_backward_recv_forward(
     input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
     tensor_shapes: List[Union[None, List[int]]],
+    *,
+    dtype: Optional[torch.dtype] = None,
 ) -> List[Union[None, torch.Tensor]]:
     if not isinstance(input_tensor_grads, list):
         input_tensor_grads = [input_tensor_grads]
@@ -128,7 +146,7 @@ def send_backward_recv_forward(
         if tensor_shape is None:
             input_tensors.append(None)
             continue
-        input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)
+        input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
         input_tensors.append(input_tensor)
     return input_tensors
@@ -141,6 +159,7 @@ def forward_backward_pipelining_without_interleaving(
     forward_only: bool,
     tensor_shape: Optional[Union[List[int], torch.Size]] = None,
     decoder_sequence_length: Optional[int] = None,
+    dtype: Optional[torch.dtype] = None,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -160,6 +179,8 @@ def forward_backward_pipelining_without_interleaving(
     Keyword args:
         forward_only:
         tensor_shape: Shape of tensor. Required for P2P communication.
+        dtype: dtype used in p2p communication. If ``None`` (default value),
+            torch.float32 will be used even if ``autocast`` is enabled.
     Returns:
         a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -210,7 +231,7 @@ def forward_backward_pipelining_without_interleaving(
         cur_microbatch = get_kth_microbatch(batch, i)
         output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
         _logger.debug("send fwd")
-        send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
+        send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
         if not forward_only:
             input_tensors.append(input_tensor)
@@ -221,7 +242,7 @@ def forward_backward_pipelining_without_interleaving(
     # receive this tensor here.
     if num_microbatches_remaining > 0:
         _logger.debug("recv_forward before steady state start")
-        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes)
+        input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
     ###################################################################################################################
     # Run 1F1B in steady state.
@@ -237,15 +258,15 @@ def forward_backward_pipelining_without_interleaving(
         )
         if forward_only:
             _logger.debug("send fwd")
-            send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
+            send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
             if not last_iteration:
                 _logger.debug("receive fwd (last iteration)")
-                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)
+                input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
         else:
             _logger.debug("send fwd & receive bwd")
-            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes)
+            output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
             # Add input_tensor and output_tensor to end of list.
             input_tensors.append(input_tensor)
@@ -260,10 +281,10 @@ def forward_backward_pipelining_without_interleaving(
             if last_iteration:
                 input_tensor = None
                 _logger.debug("send bwd")
-                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
+                send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
             else:
                 _logger.debug("send bwd and receive fwd")
-                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
+                input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
     ###################################################################################################################
     # Run cooldown backward passes.
     ###################################################################################################################
@@ -275,11 +296,11 @@ def forward_backward_pipelining_without_interleaving(
             output_tensor = output_tensors.pop(0)
             _logger.debug("receive bwd")
-            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes)
+            output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
             input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
             _logger.debug("send bwd")
-            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
+            send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
     return losses_reduced
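Not part of the diff: a quick check of the docstring claim that with ``dtype=None`` the p2p buffers stay ``torch.float32`` even inside ``autocast``. Autocast changes the compute dtype of selected ops, not the dtype of newly allocated tensors, which is why the test below passes ``dtype=torch.half`` explicitly when autocast is enabled. Requires a CUDA device.

```python
import torch

# torch.empty is not affected by autocast, so a recv buffer allocated without an
# explicit dtype stays FP32 even when the surrounding computation runs in FP16.
with torch.cuda.amp.autocast():
    buf_default = torch.empty(2, 3, device="cuda")
    buf_explicit = torch.empty(2, 3, device="cuda", dtype=torch.half)

print(buf_default.dtype)   # torch.float32
print(buf_explicit.dtype)  # torch.float16
```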
-from typing import Optional
+import warnings
 import torch
@@ -42,6 +42,7 @@ def forward_backward_func_template(
     forward_backward_func,
     pipeline_model_parallel_size: int,
     forward_only: bool,
+    enable_autocast: bool,
 ) -> None:
     print_separator(f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}")
     virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
@@ -91,8 +92,10 @@ def forward_backward_func_template(
     tensor_shape[0] = micro_batch_size
     update_num_microbatches(0)
+    dtype = torch.half if enable_autocast else None
+    with torch.cuda.amp.autocast():
     forward_backward_func(
-        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
+        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype)
     if not forward_only:
         for m in model:
@@ -118,6 +121,13 @@ if __name__ == "__main__":
     for forward_only in (True, False):
         for name, forward_backward_func in fwd_bwd_functions.items():
+            if name == "interleaving" and torch.cuda.device_count() <= 2:
+                warnings.warn(
+                    f"There's only {torch.cuda.device_count()} gpus therefore skipping {name} "
+                    "while interleaved scheduled pipeline parallel requires >2 gpus."
+                )
+                continue
+            for enable_autocast in (True, False):
                 n_tests += 1
                 # TODO (mkozuki): Test with data parallel size > 1.
                 pipeline_model_parallel_size = world_size
@@ -128,6 +138,7 @@ if __name__ == "__main__":
                         forward_backward_func,
                         pipeline_model_parallel_size,
                         forward_only,
+                        enable_autocast=enable_autocast,
                     )
                 except Exception as e:
                     failures.append(
......
@@ -10,11 +10,11 @@ DENY_TEST = [
 ]
 MULTIGPU_TEST = [
     "pipeline_parallel_test",
-    "dynamic_batchsize_test",
 ]
 SEVERALGPU_TEST = [
     "bert_minimal_test",
     "gpt_minimal_test",
+    "dynamic_batchsize_test",
 ]
 def get_multigpu_launch_option(min_gpu):
......