Unverified Commit b88c507e authored by Masaki Kozuki, committed by GitHub

Add an argument of `dtype` to forward_backward functions to specify the dtype used in p2p comm (#1249)

* let users specify the dtype for p2p comm, taking the possibility of O2-style AMP into account

* add `dtype` argument to forward_backward functions

* fix

* better message

* add docstring of dtype

* add a link to dtype logic of p2p comm
parent e8473822
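
For context, a minimal sketch of how a caller might use the new argument. The call form follows the test update in this diff; the import path is assumed from apex's layout at the time, and `fwd_step_func`, `batch`, `model`, and the shape variables are hypothetical placeholders, not part of this commit.

import torch
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
    forward_backward_pipelining_without_interleaving,
)

# Hypothetical setup: fwd_step_func, batch, model, seq_length, micro_batch_size,
# and hidden_size are placeholders. With O2-style AMP, activations are FP16, so
# letting p2p communication use FP16 avoids forcing the tensors onto the wire in FP32.
losses = forward_backward_pipelining_without_interleaving(
    fwd_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=[seq_length, micro_batch_size, hidden_size],
    dtype=torch.half,  # new in this commit; None (the default) falls back to torch.float32
)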
@@ -16,11 +16,9 @@
from functools import reduce
import operator
from typing import Union, Optional, Tuple
import warnings
import torch
from apex._autocast_utils import _get_current_dtype
from apex.transformer import parallel_state
from apex.transformer.utils import split_tensor_into_1d_equal_chunks
from apex.transformer.utils import gather_split_1d_tensor
@@ -76,7 +74,7 @@ def _communicate(
recv_next: bool,
tensor_shape: Optional[Shape] = None,
override_scatter_gather_tensors_in_pipeline: bool = False,
dtype_: torch.dtype = torch.float,
dtype_: Optional[torch.dtype] = None,
*,
scatter_gather_tensors_in_pipeline: bool = True,
params_dtype: Optional[torch.dtype] = None,
@@ -84,6 +82,12 @@
) -> Tuple[Union[torch.Tensor, None], Union[torch.Tensor, None]]:
"""Base function for communication of tensors between stages.
dtype logic: If none of ``dtype_``, ``params_dtype``, ``fp32_residual_connection`` is specified,
torch.float32 is used.
See https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/arguments.py#L145-L159
for the details of arguments of ``dtype_``, ``params_dtype``, ``fp32_residual_connection``.
Args:
tensor_send_next: tensor to send to next rank (no tensor sent if set to None).
tensor_send_prev: tensor to send to prev rank (no tensor sent if set to None).
@@ -118,20 +122,20 @@ def _communicate(
else:
tensor_chunk_shape = tensor_shape
# NOTE(mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# The dtype logic below is copied from NVIDIA/Megatron-LM repo:
# https://github.com/NVIDIA/Megatron-LM/blob/d41696840ed0a7edb7e0499eb82a48ae112d9bb3/megatron/p2p_communication.py#L74-L81
# NOTE (mkozuki): NeMo currently implements apex AMP O2 style in pure PyTorch. In O2 style, forcing p2p comm to
# use FP32 would be a performance killer, so the `dtype_` argument is revived with a default value of `None`.
# NOTE (mkozuki): In PyTorch AMP, i.e. `torch.cuda.amp.autocast` context, activation tensors can be either FP32,
# FP16, or BF16 and there's no way to tell the dtypes of tensors on different devices in general.
# It might be possible if we restrict model architecture.
# dtype = params_dtype or torch.float
# if fp32_residual_connection:
# dtype = torch.float
# if dtype_ is not None:
# dtype = dtype_
# requires_grad = False
if dtype_ != torch.float32 or params_dtype is not None:
if torch.distributed.get_rank() == 0:
warnings.warn("Tensor P2P communications are executed in FP32")
dtype = torch.float32
dtype = params_dtype or torch.float
if fp32_residual_connection:
dtype = torch.float
requires_grad = True
if dtype_ is not None:
dtype = dtype_
requires_grad = False
if recv_prev:
tensor_recv_prev = torch.empty(
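
Read in isolation, the restored precedence is: ``params_dtype`` (defaulting to FP32), overridden by ``fp32_residual_connection``, overridden in turn by an explicit ``dtype_``. A standalone sketch of that logic (for illustration only, not the committed function):

from typing import Optional, Tuple

import torch

def _resolve_p2p_dtype(
    dtype_: Optional[torch.dtype],
    params_dtype: Optional[torch.dtype],
    fp32_residual_connection: bool,
) -> Tuple[torch.dtype, bool]:
    # Start from the parameter dtype if given, otherwise FP32.
    dtype = params_dtype or torch.float
    # FP32 residual connections force FP32 activation tensors.
    if fp32_residual_connection:
        dtype = torch.float
    requires_grad = True
    # An explicit dtype_ wins over everything else and clears requires_grad.
    if dtype_ is not None:
        dtype = dtype_
        requires_grad = False
    return dtype, requires_grad

# e.g. _resolve_p2p_dtype(torch.half, None, False) returns (torch.float16, False)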
@@ -199,7 +203,7 @@ def recv_forward(
recv_next=False,
tensor_shape=tensor_shape,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("forward-recv").stop()
@@ -223,7 +227,7 @@ def recv_backward(
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("backward-recv").stop()
@@ -250,7 +254,7 @@ def send_forward(
recv_next=False,
override_scatter_gather_tensors_in_pipeline=override_scatter_gather_tensors_in_pipeline,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("forward-send").stop()
@@ -274,7 +278,7 @@ def send_backward(
recv_prev=False,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("backward-send").stop()
@@ -298,7 +302,7 @@ def send_forward_recv_backward(
recv_prev=False,
recv_next=True,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("forward-send-backward-recv").stop()
@@ -323,7 +327,7 @@ def send_backward_recv_forward(
recv_prev=True,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("backward-send-forward-recv").stop()
@@ -347,7 +351,7 @@ def send_forward_recv_forward(
recv_prev=recv_prev,
recv_next=False,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("forward-send-forward-recv").stop()
@@ -359,7 +363,7 @@ def send_backward_recv_backward(
recv_next: bool,
tensor_shape: Shape,
*,
dtype: torch.dtype = torch.float,
dtype: Optional[torch.dtype] = None,
timers: _Timers = None,
) -> torch.Tensor:
"""Batched recv from next rank and send to previous rank in pipeline."""
@@ -371,7 +375,7 @@ def send_backward_recv_backward(
recv_prev=False,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("backward-send-backward-recv").stop()
@@ -397,7 +401,7 @@ def send_forward_backward_recv_forward_backward(
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype_=_get_current_dtype(dtype),
dtype_=dtype,
)
# if timers is not None:
# timers("forward-backward-send-forward-backward-recv").stop()
@@ -28,6 +28,7 @@ def _forward_backward_pipelining_with_interleaving(
*,
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
dtype: Optional[torch.dtype] = None,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -51,6 +52,8 @@ def _forward_backward_pipelining_with_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
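
Note that ``None`` deliberately does not inherit the autocast compute dtype, so a caller running under ``torch.cuda.amp.autocast`` has to pass the dtype explicitly. One possible helper (a sketch, assuming a PyTorch version that provides ``torch.get_autocast_gpu_dtype``):

import torch

def p2p_dtype_under_autocast() -> torch.dtype:
    # Mirror the autocast compute dtype for p2p buffers; plain FP32 otherwise.
    if torch.is_autocast_enabled():
        return torch.get_autocast_gpu_dtype()  # typically torch.float16 or torch.bfloat16
    return torch.float32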
@@ -163,7 +166,7 @@ def _forward_backward_pipelining_with_interleaving(
# Run warmup forward passes.
###################################################################################################################
parallel_state.set_virtual_pipeline_model_parallel_rank(0)
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
input_tensors[0].append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
_logger.info("Warmup phase")
for k in range(num_warmup_microbatches):
_logger.debug(f"warmup iter: {k} / {num_warmup_microbatches}")
@@ -201,12 +204,13 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
)
output_tensor_grads[num_model_chunks - 1].append(output_tensor_grad)
else:
_logger.debug("send fwd and receive fwd")
input_tensor = p2p_communication.send_forward_recv_forward(
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape)
output_tensor, recv_prev=recv_prev, tensor_shape=tensor_shape, dtype=dtype)
input_tensors[next_forward_model_chunk_id].append(input_tensor)
###################################################################################################################
@@ -281,6 +285,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_prev=recv_prev,
recv_next=recv_next,
tensor_shape=tensor_shape,
dtype=dtype,
)
# Put input_tensor and output_tensor_grad in data structures in the
@@ -296,7 +301,7 @@ def _forward_backward_pipelining_with_interleaving(
_logger.info("Cooldown phase")
if not forward_only:
if all_warmup_microbatches:
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
output_tensor_grads[num_model_chunks - 1].append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
for k in range(num_microbatches_remaining, num_microbatches):
_logger.debug(f"cooldown iter {k} in range({num_microbatches_remaining}, {num_microbatches})")
input_tensor_grad = backward_step_helper(k)
@@ -309,7 +314,7 @@ def _forward_backward_pipelining_with_interleaving(
recv_next = False
output_tensor_grads[next_backward_model_chunk_id].append(
p2p_communication.send_backward_recv_backward(
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape)
input_tensor_grad, recv_next=recv_next, tensor_shape=tensor_shape, dtype=dtype)
)
return losses_reduced
@@ -59,51 +59,67 @@ def get_tensor_shapes(
return tensor_shapes
def recv_forward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
def recv_forward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
input_tensors = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
input_tensors.append(None)
else:
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape))
input_tensors.append(p2p_communication.recv_forward(tensor_shape=tensor_shape, dtype=dtype))
return input_tensors
def recv_backward(tensor_shapes: List[Union[None, List[int]]],) -> List[Union[None, torch.Tensor]]:
def recv_backward(
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
output_tensor_grads = []
for tensor_shape in tensor_shapes:
if tensor_shape is None:
output_tensor_grads.append(None)
else:
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape))
output_tensor_grads.append(p2p_communication.recv_backward(tensor_shape=tensor_shape, dtype=dtype))
return output_tensor_grads
def send_forward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> None:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
for (output_tensor, tensor_shape) in zip(output_tensors, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape)
p2p_communication.send_forward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
def send_backward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> None:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
for (input_tensor_grad, tensor_shape) in zip(input_tensor_grads, tensor_shapes):
if tensor_shape is None:
continue
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape)
p2p_communication.send_backward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
def send_forward_recv_backward(
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]], tensor_shapes: List[Union[None, List[int]]],
output_tensors: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
if not isinstance(output_tensors, list):
output_tensors = [output_tensors]
@@ -112,7 +128,7 @@ def send_forward_recv_backward(
if tensor_shape is None:
output_tensor_grads.append(None)
continue
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape)
output_tensor_grad = p2p_communication.send_forward_recv_backward(output_tensor, tensor_shape=tensor_shape, dtype=dtype)
output_tensor_grads.append(output_tensor_grad)
return output_tensor_grads
@@ -120,6 +136,8 @@ def send_forward_recv_backward(
def send_backward_recv_forward(
input_tensor_grads: Union[torch.Tensor, List[Union[None, torch.Tensor]]],
tensor_shapes: List[Union[None, List[int]]],
*,
dtype: Optional[torch.dtype] = None,
) -> List[Union[None, torch.Tensor]]:
if not isinstance(input_tensor_grads, list):
input_tensor_grads = [input_tensor_grads]
@@ -128,7 +146,7 @@ def send_backward_recv_forward(
if tensor_shape is None:
input_tensors.append(None)
continue
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape)
input_tensor = p2p_communication.send_backward_recv_forward(input_tensor_grad, tensor_shape=tensor_shape, dtype=dtype)
input_tensors.append(input_tensor)
return input_tensors
@@ -141,6 +159,7 @@ def forward_backward_pipelining_without_interleaving(
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
decoder_sequence_length: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -160,6 +179,8 @@ def forward_backward_pipelining_without_interleaving(
Keyword args:
forward_only:
tensor_shape: Shape of tensor. Required for P2P communication.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
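
Under O2-style AMP the parameters themselves are cast to half precision, so the parameter dtype is a natural choice for the communication dtype. A sketch (``model`` here is a hypothetical ``torch.nn.Module``, not an identifier from this diff):

import torch

def p2p_dtype_from_params(model: torch.nn.Module) -> torch.dtype:
    # With O2-style AMP, parameters (and hence activations) are FP16,
    # so p2p buffers can simply reuse the parameter dtype.
    return next(model.parameters()).dtype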
@@ -210,7 +231,7 @@ def forward_backward_pipelining_without_interleaving(
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
if not forward_only:
input_tensors.append(input_tensor)
@@ -221,7 +242,7 @@ def forward_backward_pipelining_without_interleaving(
# receive this tensor here.
if num_microbatches_remaining > 0:
_logger.debug("recv_forward before steady state start")
input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes)
input_tensor: List[Union[None, torch.Tensor]] = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
###################################################################################################################
# Run 1F1B in steady state.
@@ -237,15 +258,15 @@
)
if forward_only:
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes)
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
if not last_iteration:
_logger.debug("receive fwd (last iteration)")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes)
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
else:
_logger.debug("send fwd & receive bwd")
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes)
output_tensor_grad = send_forward_recv_backward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
# Add input_tensor and output_tensor to end of list.
input_tensors.append(input_tensor)
@@ -260,10 +281,10 @@ def forward_backward_pipelining_without_interleaving(
if last_iteration:
input_tensor = None
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
else:
_logger.debug("send bwd and receive fwd")
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
input_tensor = send_backward_recv_forward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
###################################################################################################################
# Run cooldown backward passes.
###################################################################################################################
@@ -275,11 +296,11 @@ def forward_backward_pipelining_without_interleaving(
output_tensor = output_tensors.pop(0)
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes)
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes)
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
return losses_reduced
from typing import Optional
import warnings
import torch
@@ -42,6 +42,7 @@ def forward_backward_func_template(
forward_backward_func,
pipeline_model_parallel_size: int,
forward_only: bool,
enable_autocast: bool,
) -> None:
print_separator(f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}")
virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
@@ -91,8 +92,10 @@ def forward_backward_func_template(
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape)
dtype = torch.half if enable_autocast else None
with torch.cuda.amp.autocast():
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype)
if not forward_only:
for m in model:
@@ -118,28 +121,36 @@ if __name__ == "__main__":
for forward_only in (True, False):
for name, forward_backward_func in fwd_bwd_functions.items():
n_tests += 1
# TODO (mkozuki): Test with data parallel size > 1.
pipeline_model_parallel_size = world_size
try:
forward_backward_func_template(
args,
name,
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
if name == "interleaving" and torch.cuda.device_count() <= 2:
warnings.warn(
f"There's only {torch.cuda.device_count()} gpus therefore skipping {name} "
"while interleaved scheduled pipeline parallel requires >2 gpus."
)
except Exception as e:
failures.append(
f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
f"and forward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
print(failures[-1])
finally:
parallel_state.destroy_model_parallel()
continue
for enable_autocast in (True, False):
n_tests += 1
# TODO (mkozuki): Test with data parallel size > 1.
pipeline_model_parallel_size = world_size
try:
forward_backward_func_template(
args,
name,
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
enable_autocast=enable_autocast,
)
except Exception as e:
failures.append(
f"\t# {name} failed with pipeline size: {pipeline_model_parallel_size} "
f"and forward_only: {forward_only}\n"
f"pipeline rank: {parallel_state.get_pipeline_model_parallel_rank()}, "
f"virtual pipeline rank: {parallel_state.get_virtual_pipeline_model_parallel_rank()}\n"
f"{str(e)}"
)
print(failures[-1])
finally:
parallel_state.destroy_model_parallel()
else:
print_separator(f"{name} works")
print_separator("TEST RESULT")
@@ -10,11 +10,11 @@ DENY_TEST = [
]
MULTIGPU_TEST = [
"pipeline_parallel_test",
"dynamic_batchsize_test",
]
SEVERALGPU_TEST = [
"bert_minimal_test",
"gpt_minimal_test",
"dynamic_batchsize_test",
]
def get_multigpu_launch_option(min_gpu):