Unverified commit 2a4ab177, authored by Masaki Kozuki, committed by GitHub

Grad scaler (#1277)

* add keyword argument of `grad_scaler`

* update test

* pass dtype to fwd_step_func

* add log

* calc loss in autocast as per https://pytorch.org/docs/stable/amp.html#autocasting

* option to turn off autocast inside forward_step function

Some users activate `autocast` themselves outside the fwd/bwd functions.

* add missing `disable_autocast` argument

* reorder args of no pipeline
parent 45cd1001
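
For orientation, here is a minimal sketch of how a caller might pass the new keyword arguments to one of the schedule functions touched below; it mirrors the updated test at the end of this commit. `fwd_step_func`, `batch`, `model`, and `tensor_shape` are placeholders, and distributed/model-parallel setup is omitted:

    import torch
    from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import (
        forward_backward_pipelining_without_interleaving,
    )

    # fwd_step_func, batch, model, and tensor_shape are assumed to exist; see the test diff below.
    dtype = torch.half
    grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
    losses_reduced = forward_backward_pipelining_without_interleaving(
        fwd_step_func,
        batch,
        model,
        forward_only=False,
        tensor_shape=tensor_shape,
        dtype=dtype,                  # used for p2p buffers and to drive autocast inside forward_step
        grad_scaler=grad_scaler,      # loss is scaled on the last pipeline stage
        disable_autocast=False,       # set True when autocast is managed outside the fwd/bwd functions
    )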
@@ -10,6 +10,10 @@ from apex.transformer.pipeline_parallel.utils import listify_model
 from apex.transformer.pipeline_parallel.utils import unwrap_model
 from apex.transformer.pipeline_parallel.utils import get_model_type
 from apex.transformer.tensor_parallel.layers import set_defaults_if_not_set_tensor_model_parallel_attributes
+from apex.transformer.log_util import get_transformer_logger
+
+_logger = get_transformer_logger(__name__)
+
 Batch = Union[torch.Tensor, List[torch.Tensor], Tuple[torch.Tensor, ...]]
@@ -147,6 +151,8 @@ def forward_step(
     model: torch.nn.Module,
     input_tensor: Optional[Union[torch.Tensor, List[torch.Tensor]]],
     losses_reduced: List[torch.Tensor],
+    dtype: torch.dtype,
+    disable_autocast: bool = False,
 ) -> Union[torch.Tensor, Sequence[torch.Tensor]]:
     """Forward step for passed-in model.
@@ -161,6 +167,8 @@ def forward_step(
         model: unwrappable model
         input_tensor:
         losses_reduced:
+        dtype:
+        disable_autocast:
     Returns:
         output_tensor
@@ -177,12 +185,16 @@ def forward_step(
         input_tensor = [input_tensor]
     unwrapped_model.set_input_tensor(input_tensor)
-    output_tensor, loss_func = forward_step_func(batch, model)
-    if parallel_state.is_pipeline_last_stage():
-        output_tensor = loss_func(output_tensor)
-        loss, loss_reduced = output_tensor
-        output_tensor = loss / get_num_microbatches()
-        losses_reduced.append(loss_reduced)
+    with torch.cuda.amp.autocast(
+        enabled=not disable_autocast and dtype in (torch.half, torch.bfloat16),
+        dtype=dtype,
+    ):
+        output_tensor, loss_func = forward_step_func(batch, model)
+        if parallel_state.is_pipeline_last_stage():
+            output_tensor = loss_func(output_tensor)
+            loss, loss_reduced = output_tensor
+            output_tensor = loss / get_num_microbatches()
+            losses_reduced.append(loss_reduced)
     # timers("forward-compute").stop()
     # If T5 model (or other model with encoder and decoder)
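
In other words, autocast is entered only when the requested dtype is a reduced-precision one and the caller has not opted out, and the loss is computed inside the context, following the autocasting example from the PyTorch AMP docs linked in the commit message. A small illustration of the enabling rule (not part of the apex API):

    import torch

    def autocast_enabled(dtype: torch.dtype, disable_autocast: bool) -> bool:
        # Mirrors the condition above: fp32 (or an explicit opt-out) runs the
        # forward step without autocast; half/bfloat16 run inside autocast.
        return not disable_autocast and dtype in (torch.half, torch.bfloat16)

    assert autocast_enabled(torch.half, False)
    assert not autocast_enabled(torch.float32, False)
    assert not autocast_enabled(torch.bfloat16, True)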
@@ -200,6 +212,8 @@ def backward_step(
     output_tensor: torch.Tensor,
     output_tensor_grad: Optional[torch.Tensor],
     model_type: ModelType,
+    *,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
 ) -> Union[None, torch.Tensor, Sequence[torch.Tensor]]:
     """Backward step through passed-in output tensor.
@@ -234,6 +248,8 @@ def backward_step(
         output_tensor_grad = [output_tensor_grad]
     # Backward pass.
+    if grad_scaler is not None and output_tensor_grad[0] is None:
+        output_tensor[0] = grad_scaler.scale(output_tensor[0])
     torch.autograd.backward(output_tensor[0], grad_tensors=output_tensor_grad[0])
     # Collect the grad of the input_tensor.
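
The `output_tensor_grad[0] is None` check means the loss is scaled only on the last pipeline stage, where backward starts from the loss itself; earlier stages receive gradients that were already scaled downstream, so scaling again would compound the factor. The rest of the scaler's work stays with the caller. A hedged sketch of that caller side, assuming a standard optimizer and the usual `torch.cuda.amp.GradScaler` API; all other names are placeholders:

    import torch

    # forward_backward_func, fwd_step_func, batch, model, tensor_shape, optimizer,
    # and grad_scaler are placeholders; only the grad_scaler handling is the point here.
    losses_reduced = forward_backward_func(
        fwd_step_func, batch, model,
        forward_only=False, tensor_shape=tensor_shape,
        dtype=torch.half, grad_scaler=grad_scaler,
    )
    if grad_scaler is not None:
        grad_scaler.step(optimizer)   # unscales grads and skips the step if inf/NaN is found
        grad_scaler.update()          # adjusts the scale for the next iteration
    else:
        optimizer.step()
    optimizer.zero_grad()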
......
 from contextlib import contextmanager
-from typing import List, Union
+from typing import List, Union, Optional
 import torch
@@ -34,6 +34,9 @@ def forward_backward_no_pipelining(
     model: Union[torch.nn.Module, List[torch.nn.Module]],
     *,
     forward_only: bool,
+    dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
     **kwargs,
 ):
     """Run forward and backward passes with no pipeline parallelism (no inter-stage communication).
@@ -50,6 +53,9 @@ def forward_backward_no_pipelining(
     Keyword args:
         forward_only:
+        grad_scaler:
+        dtype:
+        disable_autocast:
         **kwargs: Added to handle `tensor_shape` which has no effect on this function.
     Returns:
@@ -75,20 +81,33 @@ def forward_backward_no_pipelining(
             cur_micro_batch = get_kth_microbatch(batch, i)
             _logger.debug("Call `forward_step`")
             output_tensor = forward_step(
-                forward_step_func, cur_micro_batch, model, input_tensor, losses_reduced)
+                forward_step_func,
+                cur_micro_batch,
+                model,
+                input_tensor,
+                losses_reduced,
+                dtype=dtype,
+                disable_autocast=disable_autocast,
+            )
             if not forward_only:
                 _logger.debug("Call `backward_step`")
-                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
+                backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
     # Run computation for last microbatch out of context handler (want to
     # synchronize gradients).
     _logger.info("Cooldown")
     _logger.debug("Call `forward_step`")
     output_tensor = forward_step(
-        forward_step_func, get_kth_microbatch(batch, num_micro_batches - 1), model, input_tensor, losses_reduced
+        forward_step_func,
+        get_kth_microbatch(batch, num_micro_batches - 1),
+        model,
+        input_tensor,
+        losses_reduced,
+        dtype=dtype,
+        disable_autocast=disable_autocast,
     )
     if not forward_only:
         _logger.debug("Call `backward_step`")
-        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
+        backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
     return losses_reduced
@@ -29,6 +29,8 @@ def _forward_backward_pipelining_with_interleaving(
     forward_only: bool,
     tensor_shape: Optional[Union[List[int], torch.Size]] = None,
     dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run interleaved 1F1B schedule with communication between pipeline stages as needed.
@@ -54,6 +56,8 @@ def _forward_backward_pipelining_with_interleaving(
         tensor_shape: Shape of tensor.
         dtype: dtype used in p2p communication. If ``None`` (default value),
             torch.float32 will be used even if ``autocast`` is enabled.
+        grad_scaler:
+        disable_autocast:
     Returns:
         a list of loss `torch.Tensor`s if the last stage, empty list otherwise.
@@ -132,6 +136,8 @@ def _forward_backward_pipelining_with_interleaving(
             model[model_chunk_id],
             input_tensor,
             losses_reduced,
+            dtype,
+            disable_autocast,
         )
         curr_iters[model_chunk_id] += 1
         output_tensors[model_chunk_id].append(output_tensor)
@@ -158,7 +164,7 @@ def _forward_backward_pipelining_with_interleaving(
         input_tensor = input_tensors[model_chunk_id].pop(0)
         output_tensor = output_tensors[model_chunk_id].pop(0)
         output_tensor_grad = output_tensor_grads[model_chunk_id].pop(0)
-        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
+        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
         return input_tensor_grad
......
@@ -160,6 +160,8 @@ def forward_backward_pipelining_without_interleaving(
     tensor_shape: Optional[Union[List[int], torch.Size]] = None,
     decoder_sequence_length: Optional[int] = None,
     dtype: Optional[torch.dtype] = None,
+    grad_scaler: Optional[torch.cuda.amp.GradScaler] = None,
+    disable_autocast: bool = False,
 ) -> List[Union[torch.Tensor, Sequence[torch.Tensor]]]:
     """Run non-interleaved 1F1B schedule, with communication between pipeline stages.
@@ -229,7 +231,15 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug("receive fwd")
         input_tensor = recv_forward(tensor_shapes=recv_tensor_shapes, dtype=dtype)
         cur_microbatch = get_kth_microbatch(batch, i)
-        output_tensor = forward_step(forward_step_func, cur_microbatch, model, input_tensor, losses_reduced)
+        output_tensor = forward_step(
+            forward_step_func,
+            cur_microbatch,
+            model,
+            input_tensor,
+            losses_reduced,
+            dtype,
+            disable_autocast,
+        )
         _logger.debug("send fwd")
         send_forward(output_tensor, tensor_shapes=send_tensor_shapes, dtype=dtype)
@@ -254,7 +264,13 @@ def forward_backward_pipelining_without_interleaving(
         cur_microbatch: torch.Tensor = get_kth_microbatch(batch, i + num_warmup_microbatches)
         output_tensor: Union[torch.Tensor, Sequence[torch.Tensor]] = forward_step(
-            forward_step_func, cur_microbatch, model, input_tensor, losses_reduced
+            forward_step_func,
+            cur_microbatch,
+            model,
+            input_tensor,
+            losses_reduced,
+            dtype,
+            disable_autocast,
         )
         if forward_only:
             _logger.debug("send fwd")
@@ -276,7 +292,7 @@ def forward_backward_pipelining_without_interleaving(
             input_tensor = input_tensors.pop(0)
             output_tensor = output_tensors.pop(0)
-            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
+            input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
             if last_iteration:
                 input_tensor = None
@@ -298,7 +314,7 @@ def forward_backward_pipelining_without_interleaving(
         _logger.debug("receive bwd")
         output_tensor_grad = recv_backward(tensor_shapes=send_tensor_shapes, dtype=dtype)
-        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type)
+        input_tensor_grad = backward_step(input_tensor, output_tensor, output_tensor_grad, model_type=model_type, grad_scaler=grad_scaler)
         _logger.debug("send bwd")
         send_backward(input_tensor_grad, tensor_shapes=recv_tensor_shapes, dtype=dtype)
......
@@ -13,7 +13,7 @@ from apex.transformer.pipeline_parallel.schedules.common import build_model
 from apex.transformer.pipeline_parallel.schedules.common import _get_params_for_weight_decay_optimization
 from apex.transformer.pipeline_parallel.schedules.fwd_bwd_pipelining_without_interleaving import forward_backward_pipelining_without_interleaving
 from apex.transformer.testing.standalone_gpt import gpt_model_provider
 from apex.transformer.testing import global_vars
 from apex.transformer.testing.commons import TEST_SUCCESS_MESSAGE
 from apex.transformer.testing.commons import initialize_distributed
......
+from typing import Optional
 import warnings
 import torch
+from torch.cuda.amp import GradScaler
 from apex.transformer import parallel_state
 from apex.transformer.pipeline_parallel import get_forward_backward_func
@@ -42,7 +44,8 @@ def forward_backward_func_template(
     forward_backward_func,
     pipeline_model_parallel_size: int,
     forward_only: bool,
-    enable_autocast: bool,
+    dtype: torch.dtype,
+    grad_scaler: Optional[GradScaler],
 ) -> None:
     print_separator(f"name: {name}, pipeline model parallel size: {pipeline_model_parallel_size}")
     virtual_pipeline_model_parallel_size = 2 if name == "interleaving" else None
@@ -92,10 +95,8 @@ def forward_backward_func_template(
         tensor_shape[0] = micro_batch_size
     update_num_microbatches(0)
-    dtype = torch.half if enable_autocast else None
-    with torch.cuda.amp.autocast():
-        forward_backward_func(
-            fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype)
+    forward_backward_func(
+        fwd_step_func, batch, model, forward_only=forward_only, tensor_shape=tensor_shape, dtype=dtype, grad_scaler=grad_scaler)
     if not forward_only:
         for m in model:
@@ -119,6 +120,9 @@ if __name__ == "__main__":
     batch_size = args.global_batch_size
     micro_batch_size = args.micro_batch_size
+    autocast_dtypes = (
+        [torch.half, torch.bfloat16] if torch.cuda.is_bf16_supported() else [torch.half]
+    ) + [torch.float32]
     for forward_only in (True, False):
         for name, forward_backward_func in fwd_bwd_functions.items():
             if name == "interleaving" and torch.cuda.device_count() <= 2:
@@ -127,7 +131,10 @@ if __name__ == "__main__":
                     "while interleaved scheduled pipeline parallel requires >2 gpus."
                 )
                 continue
-            for enable_autocast in (True, False):
+            for dtype in autocast_dtypes:
+                if torch.distributed.get_rank() == 0:
+                    _logger.info(f"forward_only: {forward_only}, name: {name}, dtype: {dtype}")
+                grad_scaler = torch.cuda.amp.GradScaler(init_scale=4.0) if dtype == torch.half else None
                 n_tests += 1
                 # TODO (mkozuki): Test with data parallel size > 1.
                 pipeline_model_parallel_size = world_size
@@ -138,7 +145,8 @@ if __name__ == "__main__":
                         forward_backward_func,
                         pipeline_model_parallel_size,
                         forward_only,
-                        enable_autocast=enable_autocast,
+                        dtype=dtype,
+                        grad_scaler=grad_scaler,
                     )
                 except Exception as e:
                     failures.append(
......
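
Finally, the `disable_autocast` flag covers the use case called out in the commit message: callers that already wrap training in their own `autocast` context. A hedged sketch of that pattern; everything other than the keyword arguments shown in this diff is a placeholder:

    import torch

    # The caller owns the autocast context, so forward_step must not open a second one.
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        losses_reduced = forward_backward_func(
            fwd_step_func,
            batch,
            model,
            forward_only=False,
            tensor_shape=tensor_shape,
            dtype=torch.bfloat16,     # still passed so p2p communication uses matching buffers
            disable_autocast=True,    # skip the autocast inside forward_step
        )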