Unverified commit 2a4ab177 authored by Masaki Kozuki, committed by GitHub

Grad scaler (#1277)

* add keyword argument of `grad_scaler`

* update test

* pass dtype to fwd_step_func

* add log

* calc loss in autocast as per https://pytorch.org/docs/stable/amp.html#autocasting

* option to turn off autocast inside the forward_step function

As some users enable `autocast` outside the fwd/bwd functions.

* add missing `disable_autocast` argument

* reorder args of the no-pipelining schedule
parent 45cd1001
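In short, all three schedules now accept `dtype`, `grad_scaler`, and `disable_autocast` keyword arguments. A minimal usage sketch (not part of this diff), assuming `fwd_step_func`, `batch`, `model`, `tensor_shape`, `optimizer`, and a `forward_backward_func` schedule are already set up as in the updated test at the bottom:

```python
import torch

# Hypothetical driver code; fwd_step_func, batch, model, tensor_shape,
# optimizer, and forward_backward_func are assumed to exist already.
grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0)  # fp16 only; pass None for bf16/fp32

losses_reduced = forward_backward_func(
    fwd_step_func,
    batch,
    model,
    forward_only=False,
    tensor_shape=tensor_shape,
    dtype=torch.half,         # used for autocast in forward_step and for p2p communication
    grad_scaler=grad_scaler,  # scales the loss before the backward pass
    disable_autocast=False,   # set True if autocast is already managed by the caller
)

# The schedules only scale the loss; unscaling, stepping, and updating the
# scaler remain the caller's responsibility.
grad_scaler.step(optimizer)
grad_scaler.update()
```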
@@ -10,6 +10,10 @@ from apex.transformer.pipeline_parallel.utils import listify_model
from apex.transformer.pipeline_parallel.utils import unwrap_model
from apex.transformer.pipeline_parallel.utils import get_model_type
from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
from apex.transformer.log_util import get_transformer_logger
_logger = get_transformer_logger(__name__)
Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
@@ -147,6 +151,8 @@ def forward_step(
model: torch.nn.Module,
input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
losses_reduced: List[torch.Tensor],
dtype: torch.dtype,
disable_autocast: bool = False,
) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
"""Forward step for passed-in model.
@@ -161,6 +167,8 @@ def forward_step(
model: unwrappable model
input_tensor:
losses_reduced:
dtype: dtype used for ``autocast`` around ``forward_step_func``.
disable_autocast: If ``True``, run ``forward_step_func`` outside of ``autocast``.
Returns:
output_tensor
@@ -177,12 +185,16 @@ def forward_step(
input_tensor = [input_tensor]
unwrapped_model.set_input_tensor(input_tensor)
output_tensor, loss_func = forward_step_func(batch, model)
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
with torch.cuda.amp.autocast(
enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
dtype=dtype,
):
output_tensor, loss_func = forward_step_func(batch, model)
if parallel_state.is_pipeline_last_stage():
output_tensor = loss_func(output_tensor)
loss, loss_reduced = output_tensor
output_tensor = loss / get_num_microbatches()
losses_reduced.append(loss_reduced)
# timers("forward-compute").stop()
# If T5 model (or other model with encoder and decoder)
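`forward_step` now wraps `forward_step_func` and, on the last stage, the loss computation in `torch.cuda.amp.autocast`, enabled only when `dtype` is `torch.half` or `torch.bfloat16` and `disable_autocast` is unset. For users who enable autocast outside the fwd/bwd functions, a hedged sketch of the opt-out path (same assumed names as the sketch above):

```python
# Caller-managed autocast: wrap the whole step yourself and pass
# disable_autocast=True so forward_step does not open a nested autocast context.
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
    losses_reduced = forward_backward_func(
        fwd_step_func,
        batch,
        model,
        forward_only=False,
        tensor_shape=tensor_shape,
        dtype=torch.bfloat16,   # still used as the p2p communication dtype
        disable_autocast=True,  # autocast is handled by the outer context
    )
```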
@@ -200,6 +212,8 @@ def backward_step(
output_tensor: torch.Tensor,
output_tensor_grad: Optional[torch.Tensor],
model_type: ModelType,
*,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
"""Backward step through passed-in output tensor.
@@ -234,6 +248,8 @@ def backward_step(
output_tensor_grad = [output_tensor_grad]
# Backward pass.
if grad_scaler is not None and output_tensor_grad[0] is None:
output_tensor[0] = grad_scaler.scale(output_tensor[0])
torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
# Collect the grad of the input_tensor.
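`backward_step` scales the output only when a `grad_scaler` is given and `output_tensor_grad[0] is None`, i.e. when backward starts from the loss itself (the last pipeline stage, or every microbatch in the no-pipelining schedule); earlier stages already receive scaled gradients from the stage after them. A tiny standalone illustration of why scaling the loss once is sufficient:

```python
import torch

# Gradients of (scale * loss) equal scale * (gradients of loss), so scaling at
# the point where backward starts propagates the scale to every earlier stage.
x = torch.randn(4, requires_grad=True)
loss = (x ** 2).sum()
scale = 4.0

(g_unscaled,) = torch.autograd.grad(loss, x, retain_graph=True)
(g_scaled,) = torch.autograd.grad(scale * loss, x)
assert torch.allclose(g_scaled, scale * g_unscaled)
```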
from contextlib import contextmanager
from typing import List, Union
from typing import List, Union, Optional
import torch
@@ -34,6 +34,9 @@ def forward_backward_no_pipelining(
model: Union[torch.nn.Module, List[torch.nn.Module]],
*,
forward_only: bool,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
**kwargs,
):
"""Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
@@ -50,6 +53,9 @@ def forward_backward_no_pipelining(
Keyword args:
forward_only:
dtype: dtype used for ``autocast`` around ``forward_step_func``.
grad_scaler: ``torch.cuda.amp.GradScaler`` used to scale the loss before backward.
disable_autocast: If ``True``, run ``forward_step_func`` outside of ``autocast``.
**kwargs: Added to handle `tensor_shape` which has no effect on this function.
Returns:
@@ -75,20 +81,33 @@ def forward_backward_no_pipelining(
cur_micro_batch = get_kth_microbatch(batch, i)
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
forward_step_func,
cur_micro_batch,
model,
input_tensor,
losses_reduced,
dtype=dtype,
disable_autocast=disable_autocast,
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
# Run computation for last microbatch out of context handler (want to
# synchronize gradients).
_logger.info("Cooldown")
_logger.debug("Call `forward_step`")
output_tensor = forward_step(
forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
forward_step_func,
get_kth_microbatch(batch, num_micro_batches - 1),
model,
input_tensor,
losses_reduced,
dtype=dtype,
disable_autocast=disable_autocast,
)
if not forward_only:
_logger.debug("Call `backward_step`")
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
return losses_reduced
@@ -29,6 +29,8 @@ def _forward_backward_pipelining_with_interleaving(
forward_only: bool,
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -54,6 +56,8 @@ def _forward_backward_pipelining_with_interleaving(
tensor_shape: Shape of tensor.
dtype: dtype used in p2p communication. If ``None`` (default value),
torch.float32 will be used even if ``autocast`` is enabled.
grad_scaler: ``torch.cuda.amp.GradScaler`` used to scale the last-stage loss before backward.
disable_autocast: If ``True``, run ``forward_step_func`` outside of ``autocast``.
Returns:
a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -132,6 +136,8 @@ def _forward_backward_pipelining_with_interleaving(
model[model_chunk_id],
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
curr_iters[model_chunk_id] += 1
output_tensors[model_chunk_id].append(output_tensor)
@@ -158,7 +164,7 @@ def _forward_backward_pipelining_with_interleaving(
input_tensor = input_tensors[model_chunk_id].pop(0)
output_tensor = output_tensors[model_chunk_id].pop(0)
output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
return input_tensor_grad
@@ -160,6 +160,8 @@ def forward_backward_pipelining_without_interleaving(
tensor_shape: Optional[Union[List[int], torch.Size]] = None,
decoder_sequence_length: Optional[int] = None,
dtype: Optional[torch.dtype] = None,
grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
disable_autocast: bool = False,
) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
"""Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -229,7 +231,15 @@ def forward_backward_pipelining_without_interleaving(
_logger.debug("receive fwd")
input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
cur_microbatch = get_kth_microbatch(batch, i)
output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
output_tensor = forward_step(
forward_step_func,
cur_microbatch,
model,
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
_logger.debug("send fwd")
send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
@@ -254,7 +264,13 @@ def forward_backward_pipelining_without_interleaving(
cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
forward_step_func, cur_microbatch, model, input_tensor, losses_reduced
forward_step_func,
cur_microbatch,
model,
input_tensor,
losses_reduced,
dtype,
disable_autocast,
)
if forward_only:
_logger.debug("send fwd")
@@ -276,7 +292,7 @@ def forward_backward_pipelining_without_interleaving(
input_tensor = input_tensors.pop(0)
output_tensor = output_tensors.pop(0)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
if last_iteration:
input_tensor = None
@@ -298,7 +314,7 @@ def forward_backward_pipelining_without_interleaving(
_logger.debug("receive bwd")
output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
_logger.debug("send bwd")
send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
@@ -13,7 +13,7 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
from apex.transformer.testing.standalone_gpt import gpt_model_provider
from apex.transformer.testing import global_vars
from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
from apex.transformer.testing.commons import initialize_distributed
from typing import Optional
import warnings
import torch
from torch.cuda.amp import GradScaler
from apex.transformer import parallel_state
from apex.transformer.pipeline_parallel import get_forward_backward_func
@@ -42,7 +44,8 @@ def forward_backward_func_template(
forward_backward_func,
pipeline_model_parallel_size: int,
forward_only: bool,
enable_autocast: bool,
dtype: torch.dtype,
grad_scaler: Optional[GradScaler],
) -> None:
print_separator(f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}")
virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
@@ -92,10 +95,8 @@ def forward_backward_func_template(
tensor_shape[0] = micro_batch_size
update_num_microbatches(0)
dtype = torch.half if enable_autocast else None
with torch.cuda.amp.autocast():
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype)
forward_backward_func(
fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype, grad_scaler=grad_scaler)
if not forward_only:
for m in model:
@@ -119,6 +120,9 @@ if __name__ == "__main__":
batch_size = args.global_batch_size
micro_batch_size = args.micro_batch_size
autocast_dtypes = (
[torch.half, torch.bfloat16] if torch.cuda.is_bf16_supported() else [torch.half]
) + [torch.float32]
for forward_only in (True, False):
for name, forward_backward_func in fwd_bwd_functions.items():
if name == "interleaving" and torch.cuda.device_count() <= 2:
@@ -127,7 +131,10 @@ if __name__ == "__main__":
"while interleaved scheduled pipeline parallel requires >2 gpus."
)
continue
for enable_autocast in (True, False):
for dtype in autocast_dtypes:
if torch.distributed.get_rank() == 0:
_logger.info(f"forward_only: {forward_only}, name: {name}, dtype: {dtype}")
grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
n_tests += 1
# TODO (mkozuki): Test with data parallel size > 1.
pipeline_model_parallel_size = world_size
@@ -138,7 +145,8 @@ if __name__ == "__main__":
forward_backward_func,
pipeline_model_parallel_size,
forward_only,
enable_autocast=enable_autocast,
dtype=dtype,
grad_scaler=grad_scaler,
)
except Exception as e:
failures.append(