Unverified Commit 2b7d280b authored by Masaki Kozuki, committed by GitHub

[transformer][pipeline parallel] warn if deallocation is enabled (#1365)

This is cherry-picked for easier comparison with megatron-lm.
parent 77f9d73c
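For context on what the new warning guards: in Megatron-style pipeline schedules, the deallocation option frees each stage's output activation once it has been sent downstream, keeping only a tiny placeholder tensor so the Python object (and the autograd bookkeeping that references it) stays alive. The sketch below illustrates that general idea only; it is an assumption about the technique, not apex's implementation, and `free_output_storage` is a hypothetical helper name.

```python
import torch

def free_output_storage(out: torch.Tensor) -> None:
    # Rough sketch of the deallocation idea (an assumption, not apex's code):
    # point `.data` at a 1-element placeholder so the large activation buffer
    # behind `out` can be reclaimed while the Python object stays alive.
    out.data = torch.empty((1,), device=out.device, dtype=out.dtype)

stage_output = torch.randn(16, 1024)   # stand-in for a pipeline-stage output
free_output_storage(stage_output)      # the 16x1024 buffer is now collectible
print(stage_output.shape)              # torch.Size([1])
```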
@@ -335,6 +335,9 @@ def backward_step(
        input_tensor:
        output_tensor:
        output_tensor_grad:
    Keyword Arguments:
        grad_scaler:
        deallocate_pipeline_outputs: Experimental.
    Returns:
        input_tensor_grad
    """
...
from typing import List, Union, Optional, Sequence
import warnings
import torch
@@ -70,6 +71,12 @@ def _forward_backward_pipelining_with_interleaving(
    if not isinstance(model, list):
        raise RuntimeError("`model` must be a list of `nn.Module`'s'")
    if deallocate_pipeline_outputs:
        warnings.warn(
            "`deallocate_pipeline_outputs` is experimental and subject to change. "
            "This option is not recommended."
        )
    num_model_chunks: int = len(model)
    input_tensors: List[List[Union[None, torch.Tensor]]] = [
        [] for _ in range(num_model_chunks)
...
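Both schedules now emit the same message when the flag is enabled. A caller who has deliberately opted into the experimental deallocation path can silence it with the standard library's warning filters; nothing apex-specific is required. A minimal sketch follows (the actual schedule invocation is omitted, since it needs a model, a data iterator, and an initialized parallel state):

```python
import warnings

# Silence only this specific message, leaving other warnings visible.
warnings.filterwarnings(
    "ignore",
    message=r"`deallocate_pipeline_outputs` is experimental",
)

# Or scope the suppression to a single call site:
with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    ...  # run the pipeline schedule with deallocate_pipeline_outputs=True here
```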
from typing import Union, List, Optional, Sequence
import warnings
import torch
@@ -196,6 +197,12 @@ def forward_backward_pipelining_without_interleaving(
    """
    # timers = get_timers()
    if deallocate_pipeline_outputs:
        warnings.warn(
            "`deallocate_pipeline_outputs` is experimental and subject to change. "
            "This option is not recommended."
        )
    model: List[torch.nn.Module] = listify_model(model)
    if len(model) != 1:
        msg = f"`model` is expected be a `nn.Module`, but {type(model)}"
...
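To check that the guard behaves as intended, the warning can be recorded in a test. The sketch below uses a hypothetical stand-in function (`run_schedule`) rather than apex's real entry points, which require a model, data iterators, and distributed setup; it only mirrors the conditional `warnings.warn` added by this commit.

```python
import warnings

def run_schedule(deallocate_pipeline_outputs: bool = False) -> None:
    # Hypothetical stand-in for the two schedule entry points touched here;
    # it only reproduces the warning behaviour added by this commit.
    if deallocate_pipeline_outputs:
        warnings.warn(
            "`deallocate_pipeline_outputs` is experimental and subject to change. "
            "This option is not recommended."
        )

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    run_schedule(deallocate_pipeline_outputs=True)

assert len(caught) == 1
assert issubclass(caught[0].category, UserWarning)
assert "experimental" in str(caught[0].message)
```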